Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
Reintroduce with perf fixes: feature: unify new_tokens format sample state to trtllm sampler tokens format (#5513)
58a8a8f - these changes were previously merged to main here. 6aef149 - the changes were temporarily reverted in main due to a significant perf regression in models using the TorchSampler (observed by @byshiue). This PR re-merges those changes along with a fix to prevent the regression. The first commit of this PR is simply the reverted revert - filter it out of the changes to see the previously unmerged changes. Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
parent f28cd3056e
commit 6ee94c7ac8
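
The central change in the diff below is that TorchSampler is now constructed from a TorchSampler.Args bundle and stores new tokens in a single preallocated tensor indexed as [step, seq_slot, beam], matching the layout of the C++ DecoderState.getAllNewTokens() used by TRTLLMSampler. The following is a minimal standalone sketch of that layout, assuming only PyTorch; SamplerArgs and MiniTorchSampler are simplified stand-ins written for illustration and are not the library's actual classes, though the field names mirror TorchSampler.Args in the diff.

# Sketch of the unified new_tokens layout adopted by this change.
# new_tokens is indexed as [step, seq_slot, beam]; requests read their token
# at their sequence slot, so the sampler can scatter a whole batch at once
# (e.g. with index_copy_) instead of appending tokens per request.
from dataclasses import dataclass

import torch


@dataclass(frozen=True, kw_only=True)
class SamplerArgs:  # mirrors the fields of TorchSampler.Args in the diff
    max_seq_len: int
    max_draft_tokens: int
    max_num_sequences: int
    max_beam_width: int
    mixed_sampler: bool


class MiniTorchSampler:
    BEAM = 0
    MAX_BEAM_WIDTH = BEAM + 1

    def __init__(self, args: SamplerArgs):
        assert args.max_beam_width == self.MAX_BEAM_WIDTH, "only beam_width == 1 is supported"
        # One row per decoding step (1 target token + max_draft_tokens draft positions).
        self.max_tokens = args.max_draft_tokens + 1
        self.num_seq_slots = args.max_num_sequences
        # Shape: (max_tokens, num_seq_slots, beam_width), cf. DecoderState.getAllNewTokens().
        self.new_tokens = torch.zeros(
            (self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH),
            dtype=torch.int)

    def read_token(self, seq_slot: int, step: int = 0, beam: int = 0) -> int:
        # Same indexing convention as add_token() in the diff: [step, seq_slot, beam].
        return int(self.new_tokens[step, seq_slot, beam])


if __name__ == "__main__":
    args = SamplerArgs(max_seq_len=2048, max_draft_tokens=3,
                       max_num_sequences=8, max_beam_width=1,
                       mixed_sampler=False)
    sampler = MiniTorchSampler(args)
    print(sampler.new_tokens.shape)  # torch.Size([4, 8, 1])

Because every request owns a fixed seq_slot column in this tensor, the host copy produced by sample_async can be shared unchanged between TorchSampler and TRTLLMSampler, which is what the unification in the hunks below relies on.
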
@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from torch._prims_common import DeviceLikeType
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
from ...._utils import mpi_rank, mpi_world_size
|
||||
@ -256,6 +257,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
|
||||
ad_config: LlmArgs = executor_config.pytorch_backend_config
|
||||
|
||||
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
|
||||
# some derivative properties
|
||||
max_draft_tokens = (
|
||||
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
|
||||
@ -272,7 +274,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
)
|
||||
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
|
||||
seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
|
||||
resource_manager = ResourceManager(
|
||||
{
|
||||
ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
|
||||
ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
|
||||
}
|
||||
)
|
||||
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
|
||||
|
||||
# scheduling
|
||||
@ -287,10 +295,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
|
||||
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
|
||||
# correctly for models as needed.
|
||||
sampler = TorchSampler(
|
||||
sampler_args = TorchSampler.Args(
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_draft_tokens=max_draft_tokens,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
mixed_sampler=ad_config.mixed_sampler,
|
||||
)
|
||||
sampler = TorchSampler(sampler_args)
|
||||
|
||||
# creating the executor object
|
||||
py_executor = PyExecutor(
|
||||
@ -299,6 +311,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
model_engine=engine,
|
||||
sampler=sampler,
|
||||
dist=mpi_dist,
|
||||
max_num_sequences=max_num_sequences,
|
||||
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
|
||||
max_input_len=ad_config.max_input_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
|
||||
@ -26,8 +26,7 @@ from .py_executor import PyExecutor
|
||||
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
|
||||
PeftCacheManager, ResourceManager,
|
||||
ResourceManagerType)
|
||||
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
|
||||
TRTLLMSampler)
|
||||
from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
|
||||
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
|
||||
SimpleScheduler)
|
||||
from .seq_slot_manager import SeqSlotManager
|
||||
@ -514,6 +513,7 @@ def create_py_executor_instance(
|
||||
sampler=sampler,
|
||||
drafter=drafter,
|
||||
dist=dist,
|
||||
max_num_sequences=max_num_sequences,
|
||||
disable_overlap_scheduler=pytorch_backend_config.
|
||||
disable_overlap_scheduler,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
@ -525,27 +525,44 @@ def create_py_executor_instance(
|
||||
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
|
||||
|
||||
|
||||
def instantiate_sampler(model_engine: PyTorchModelEngine,
|
||||
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
|
||||
*, max_seq_len: int, mixed_sampler: bool):
|
||||
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
|
||||
max_draft_tokens = (0 if executor_config.speculative_config is None else
|
||||
executor_config.speculative_config.max_draft_tokens)
|
||||
return TorchSampler.Args(
|
||||
max_seq_len=max_seq_len,
|
||||
max_draft_tokens=max_draft_tokens,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
mixed_sampler=mixed_sampler,
|
||||
)
|
||||
|
||||
|
||||
def instantiate_sampler(engine: PyTorchModelEngine,
|
||||
executor_config: ExecutorConfig,
|
||||
pytorch_backend_config: PyTorchConfig,
|
||||
mapping: Mapping):
|
||||
sampler_args = create_torch_sampler_args(
|
||||
executor_config,
|
||||
mapping,
|
||||
max_seq_len=engine.max_seq_len,
|
||||
mixed_sampler=pytorch_backend_config.mixed_sampler)
|
||||
if mapping.cp_config.get('cp_type') == 'star_attention':
|
||||
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
|
||||
return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
|
||||
spec_config = model_engine.spec_config
|
||||
if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
|
||||
return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
|
||||
spec_config=spec_config)
|
||||
return TorchSampler(sampler_args)
|
||||
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
|
||||
):
|
||||
return get_spec_decoder(sampler_args, engine.spec_config)
|
||||
if pytorch_backend_config.enable_trtllm_sampler:
|
||||
return TRTLLMSampler(executor_config, model_engine.model,
|
||||
model_engine.dtype, mapping,
|
||||
get_decoding_mode(executor_config),
|
||||
decoding_mode = get_decoding_mode(executor_config)
|
||||
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
|
||||
mapping, decoding_mode,
|
||||
pytorch_backend_config.disable_overlap_scheduler)
|
||||
elif not model_engine.model.model_config.is_generation:
|
||||
if not engine.model.model_config.is_generation:
|
||||
# NOTE: choose sampler based on model type
|
||||
return EarlyStopSampler()
|
||||
return TorchSampler(max_seq_len=model_engine.max_seq_len,
|
||||
mixed_sampler=pytorch_backend_config.mixed_sampler)
|
||||
return TorchSampler(sampler_args)
|
||||
|
||||
|
||||
def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import itertools
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
@ -52,8 +51,7 @@ class GuidedDecoder:
|
||||
|
||||
def build(self, scheduled_requests: ScheduledRequests,
|
||||
resource_manager: SeqSlotManager) -> None:
|
||||
for llm_req in itertools.chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests):
|
||||
for llm_req in scheduled_requests.all_requests():
|
||||
if llm_req.guided_decoding_params is None:
|
||||
continue
|
||||
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
|
||||
@ -84,9 +82,7 @@ class GuidedDecoder:
|
||||
torch.cuda.current_stream().wait_stream(self._stream)
|
||||
|
||||
batched_logits, batched_bitmask = [], []
|
||||
for i, llm_req in enumerate(
|
||||
itertools.chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests)):
|
||||
for i, llm_req in enumerate(scheduled_requests.all_requests()):
|
||||
if llm_req.guided_decoding_params is None:
|
||||
continue
|
||||
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
|
||||
|
||||
@ -254,6 +254,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
|
||||
exclude_last_generation_logits: bool = False,
|
||||
return_perf_metrics: bool = False,
|
||||
stop_words_list: list[list[int]] | None = None,
|
||||
is_draft: bool = False,
|
||||
**kwargs):
|
||||
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
|
||||
None)
|
||||
@ -288,6 +289,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
|
||||
self.py_return_context_logits = return_context_logits
|
||||
self.py_return_generation_logits = return_generation_logits
|
||||
self.py_return_logits_device_memory = return_logits_device_memory
|
||||
self.py_is_draft = is_draft
|
||||
|
||||
# TODO: remove this when use DynamicDecodeOp in pytorch flow.
|
||||
# currently, keep py_stop_words_list as python list, rather than tensor.
|
||||
|
||||
@ -4,7 +4,6 @@ import functools
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import itertools
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -21,6 +20,7 @@ import torch
|
||||
import torch._dynamo.config
|
||||
|
||||
import tensorrt_llm.bindings.internal.userbuffers as ub
|
||||
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
|
||||
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
|
||||
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
|
||||
from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
|
||||
@ -319,6 +319,7 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
|
||||
|
||||
|
||||
class PyTorchModelEngine(ModelEngine):
|
||||
BEAM_WIDTH = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -659,13 +660,12 @@ class PyTorchModelEngine(ModelEngine):
|
||||
return result
|
||||
|
||||
@contextlib.contextmanager
|
||||
def release_batch(result):
|
||||
def release_batch(result: ScheduledRequests | None):
|
||||
try:
|
||||
yield result
|
||||
finally:
|
||||
if result is not None:
|
||||
for req in itertools.chain(result.generation_requests,
|
||||
result.context_requests):
|
||||
for req in result.all_requests():
|
||||
kv_cache_manager.free_resources(req)
|
||||
if spec_resource_manager is not None:
|
||||
spec_resource_manager.free_resources(req)
|
||||
@ -1153,7 +1153,15 @@ class PyTorchModelEngine(ModelEngine):
|
||||
draft_lens = []
|
||||
mrope_config = defaultdict(list)
|
||||
|
||||
batch_idx = 0
|
||||
mtp_batch_idx = 0 # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially
|
||||
|
||||
def py_batch_idx(request: LlmRequest) -> int:
|
||||
if not self.without_logits:
|
||||
return request.seq_slot
|
||||
nonlocal mtp_batch_idx
|
||||
batch_idx = mtp_batch_idx
|
||||
mtp_batch_idx += 1
|
||||
return batch_idx
|
||||
|
||||
for request in scheduled_requests.context_requests:
|
||||
request_ids.append(request.py_request_id)
|
||||
@ -1184,10 +1192,9 @@ class PyTorchModelEngine(ModelEngine):
|
||||
) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
|
||||
mrope_config['mrope_rotary_cos_sin'].append(
|
||||
mrope_rotary_cos_sin.to('cuda', non_blocking=True))
|
||||
request.py_batch_idx = batch_idx
|
||||
batch_idx += 1
|
||||
request.py_batch_idx = py_batch_idx(request)
|
||||
|
||||
num_ctx_requests = batch_idx
|
||||
num_ctx_requests = len(scheduled_requests.context_requests)
|
||||
num_ctx_tokens = len(input_ids)
|
||||
new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None
|
||||
if new_tensors_device is not None:
|
||||
@ -1227,7 +1234,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
assert spec_dec_mode.support_overlap_scheduler(
|
||||
), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"
|
||||
|
||||
# will contain previous batch incices of generation requests
|
||||
# will contain previous batch indices of generation requests
|
||||
previous_batch_indices = []
|
||||
previous_pos_indices = []
|
||||
for request in extend_requests:
|
||||
@ -1267,13 +1274,11 @@ class PyTorchModelEngine(ModelEngine):
|
||||
num_cached_tokens_per_seq.append(past_seen_token_num)
|
||||
request_ids.append(request.py_request_id)
|
||||
# update batch index
|
||||
request.py_batch_idx = batch_idx
|
||||
batch_idx += 1
|
||||
request.py_batch_idx = py_batch_idx(request)
|
||||
else:
|
||||
# update batch index
|
||||
previous_batch_idx = request.py_batch_idx
|
||||
request.py_batch_idx = batch_idx
|
||||
batch_idx += 1
|
||||
request.py_batch_idx = py_batch_idx(request)
|
||||
# inputs
|
||||
# overlap scheduler can only support the speculative decoding
|
||||
# methods with a fixed number of draft tokens
|
||||
@ -1324,12 +1329,21 @@ class PyTorchModelEngine(ModelEngine):
|
||||
prompt_lengths.append(request.py_prompt_len)
|
||||
draft_lens.append(0)
|
||||
|
||||
request.py_batch_idx = batch_idx
|
||||
batch_idx += 1
|
||||
request.py_batch_idx = py_batch_idx(request)
|
||||
|
||||
previous_batch_len = len(previous_batch_indices)
|
||||
|
||||
def previous_seq_slots_device():
|
||||
previous_batch_indices_host = torch.tensor(previous_batch_indices,
|
||||
dtype=torch.int,
|
||||
pin_memory=True)
|
||||
previous_slots = self.previous_batch_indices_cuda[:
|
||||
previous_batch_len]
|
||||
previous_slots.copy_(previous_batch_indices_host, non_blocking=True)
|
||||
return previous_slots
|
||||
|
||||
num_tokens = len(input_ids)
|
||||
num_draft_tokens = len(draft_tokens)
|
||||
previous_batchs = len(previous_batch_indices)
|
||||
num_requests = len(request_ids)
|
||||
total_num_tokens = len(position_ids)
|
||||
assert total_num_tokens <= self.max_num_tokens, (
|
||||
@ -1347,67 +1361,55 @@ class PyTorchModelEngine(ModelEngine):
|
||||
self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
|
||||
non_blocking=True)
|
||||
if next_draft_tokens_device is not None:
|
||||
if len(previous_batch_indices) > 0:
|
||||
previous_batch_indices = torch.tensor(previous_batch_indices,
|
||||
dtype=torch.int,
|
||||
pin_memory=True)
|
||||
self.previous_batch_indices_cuda[:previous_batchs].copy_(
|
||||
previous_batch_indices, non_blocking=True)
|
||||
if previous_batch_len > 0:
|
||||
previous_slots = previous_seq_slots_device()
|
||||
# previous input ids
|
||||
previous_batch_tokens = previous_batchs * (1 +
|
||||
self.max_draft_len)
|
||||
self.input_ids_cuda[
|
||||
num_tokens:num_tokens +
|
||||
previous_batch_tokens].copy_(new_tokens_device[
|
||||
self.previous_batch_indices_cuda[:previous_batchs], :].
|
||||
flatten(),
|
||||
non_blocking=True)
|
||||
previous_batch_tokens = previous_batch_len * (
|
||||
1 + self.max_draft_len)
|
||||
new_tokens = new_tokens_device[previous_slots, :].flatten()
|
||||
self.input_ids_cuda[num_tokens:num_tokens +
|
||||
previous_batch_tokens].copy_(
|
||||
new_tokens, non_blocking=True)
|
||||
# previous draft tokens
|
||||
previous_batch_draft_tokens = previous_batchs * self.max_draft_len
|
||||
self.draft_tokens_cuda[
|
||||
num_draft_tokens:num_draft_tokens +
|
||||
previous_batch_draft_tokens].copy_(next_draft_tokens_device[
|
||||
self.previous_batch_indices_cuda[:previous_batchs], :].
|
||||
flatten(),
|
||||
non_blocking=True)
|
||||
previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
|
||||
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
|
||||
previous_batch_draft_tokens].copy_(
|
||||
next_draft_tokens_device[
|
||||
previous_slots, :].flatten(),
|
||||
non_blocking=True)
|
||||
# prepare data for the preprocess inputs
|
||||
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
|
||||
previous_pos_indices = torch.tensor(previous_pos_indices,
|
||||
dtype=torch.int,
|
||||
pin_memory=True)
|
||||
previous_pos_indices_host = torch.tensor(previous_pos_indices,
|
||||
dtype=torch.int,
|
||||
pin_memory=True)
|
||||
self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
|
||||
previous_pos_indices, non_blocking=True)
|
||||
previous_pos_indices_host, non_blocking=True)
|
||||
self.previous_pos_id_offsets_cuda[
|
||||
0:previous_batch_tokens].copy_(
|
||||
new_tokens_lens_device[self.previous_pos_indices_cuda[
|
||||
0:previous_batch_tokens]],
|
||||
non_blocking=True)
|
||||
self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
|
||||
kv_len_offsets_device[
|
||||
self.previous_batch_indices_cuda[:previous_batchs]],
|
||||
non_blocking=True)
|
||||
self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
|
||||
kv_len_offsets_device[previous_slots], non_blocking=True)
|
||||
# for the requests that do not have previous batch, set the previous_pos_id_offsets and
|
||||
# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
|
||||
self.previous_pos_id_offsets_cuda[
|
||||
previous_batch_tokens:num_requests *
|
||||
(1 + self.max_draft_len)] *= 0
|
||||
self.previous_kv_lens_offsets_cuda[
|
||||
previous_batchs:num_requests] *= 0
|
||||
previous_batch_len:num_requests] *= 0
|
||||
else:
|
||||
# change the data to zeros to skip the value changes in _preprocess_inputs
|
||||
self.previous_pos_id_offsets_cuda *= 0
|
||||
self.previous_kv_lens_offsets_cuda *= 0
|
||||
elif new_tokens_device is not None:
|
||||
previous_batch_tokens = len(previous_batch_indices)
|
||||
previous_batch_indices = torch.tensor(previous_batch_indices,
|
||||
dtype=torch.int,
|
||||
pin_memory=True)
|
||||
self.previous_batch_indices_cuda[:previous_batch_tokens].copy_(
|
||||
previous_batch_indices, non_blocking=True)
|
||||
self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_(
|
||||
new_tokens_device[
|
||||
self.previous_batch_indices_cuda[:previous_batchs]],
|
||||
non_blocking=True)
|
||||
seq_slots_device = previous_seq_slots_device()
|
||||
max_draft_len = max(draft_lens)
|
||||
new_tokens = new_tokens_device[:max_draft_len + 1,
|
||||
seq_slots_device, :self.BEAM_WIDTH]
|
||||
self.input_ids_cuda[num_tokens:num_tokens +
|
||||
previous_batch_len].copy_(new_tokens.flatten(),
|
||||
non_blocking=True)
|
||||
|
||||
position_ids = torch.tensor(position_ids,
|
||||
dtype=torch.int,
|
||||
@ -1645,7 +1647,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
# for star attention, we need customized block ids
|
||||
block_ids_per_seq = []
|
||||
num_cached_tokens_per_seq = []
|
||||
output_token_idx = 0
|
||||
for request in scheduled_requests.context_requests:
|
||||
request_ids.append(request.py_request_id)
|
||||
prompt_lengths.append(request.py_prompt_len)
|
||||
@ -1702,8 +1703,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
sequence_lengths.append(len(input_id))
|
||||
block_ids_per_seq.extend([all_cache_indices])
|
||||
num_cached_tokens_per_seq.append(past_seen_token_num)
|
||||
request.output_token_idx = output_token_idx
|
||||
output_token_idx += 1
|
||||
num_contexts = len(sequence_lengths)
|
||||
for request in scheduled_requests.context_requests:
|
||||
ctx_iter = request.ctx_iters
|
||||
@ -1743,8 +1742,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
sequence_lengths.append(len(input_id))
|
||||
block_ids_per_seq.extend([all_cache_indices])
|
||||
num_cached_tokens_per_seq.append(past_seen_token_num)
|
||||
request.output_token_idx = output_token_idx
|
||||
output_token_idx += 1
|
||||
num_queries = len(sequence_lengths) - num_contexts
|
||||
|
||||
# Requests with draft tokens are treated like extend requests.
|
||||
@ -1802,8 +1799,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
position_ids.append(last_query_pos_id + request.gen_iters + 1)
|
||||
block_ids_per_seq.extend([all_cache_indices])
|
||||
num_cached_tokens_per_seq.append(past_seen_token_num)
|
||||
request.output_token_idx = output_token_idx
|
||||
output_token_idx += 1
|
||||
|
||||
num_tokens = len(input_ids)
|
||||
assert num_tokens <= self.max_num_tokens, (
|
||||
@ -2171,9 +2166,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
num_ctx_req = len(scheduled_requests.context_requests)
|
||||
logits_tensor = outputs["logits"]
|
||||
|
||||
for idx, request in enumerate(
|
||||
itertools.chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests)):
|
||||
for idx, request in enumerate(scheduled_requests.all_requests()):
|
||||
logits_processors = getattr(request, "py_logits_post_processors",
|
||||
None)
|
||||
if not logits_processors:
|
||||
|
||||
@ -11,12 +11,12 @@ import traceback
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
|
||||
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
|
||||
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
|
||||
is_trace_enabled, nvtx_range, trace_func)
|
||||
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
|
||||
@ -36,7 +36,7 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
|
||||
LlmResponse, executor_request_to_llm_request)
|
||||
from .model_engine import ModelEngine
|
||||
from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
|
||||
from .scheduler import ScheduledRequests
|
||||
from .scheduler import RequestScheduler, ScheduledRequests
|
||||
|
||||
# Environment variable to specify iteration ranges for profiling start/stop.
|
||||
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
|
||||
@ -163,10 +163,11 @@ class PyExecutor:
|
||||
|
||||
def __init__(self,
|
||||
resource_manager,
|
||||
scheduler,
|
||||
scheduler: RequestScheduler,
|
||||
model_engine: ModelEngine,
|
||||
sampler: Sampler,
|
||||
dist: Distributed,
|
||||
max_num_sequences: int,
|
||||
drafter: Drafter = None,
|
||||
disable_overlap_scheduler: bool = False,
|
||||
max_input_len: int = 2048,
|
||||
@ -271,11 +272,13 @@ class PyExecutor:
|
||||
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
|
||||
self.event_loop = trace_func(self.event_loop)
|
||||
|
||||
if self.draft_model_engine is not None and self.event_loop.__name__ != self._executor_loop.__name__:
|
||||
raise NotImplementedError(
|
||||
"Drafting is not supported for selected executor loop. "
|
||||
"Please disable disagg/pipeline parallelism/overlap scheduler.")
|
||||
|
||||
if self.draft_model_engine is not None:
|
||||
if self.event_loop.__name__ != self._executor_loop.__name__:
|
||||
raise NotImplementedError(
|
||||
"Drafting is not supported for selected executor loop. "
|
||||
"Please disable disagg/pipeline parallelism/overlap scheduler."
|
||||
)
|
||||
self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
|
||||
self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
|
||||
|
||||
self.worker_started = False
|
||||
@ -757,7 +760,7 @@ class PyExecutor:
|
||||
"cpu", non_blocking=True)
|
||||
sample_state = self._sample_async(
|
||||
scheduled_batch, batch_outputs)
|
||||
sample_state.logits_host = logits_host
|
||||
sample_state.host.logits = logits_host
|
||||
self._update_request_states(scheduled_batch)
|
||||
|
||||
if self.enable_iter_perf_stats:
|
||||
@ -788,7 +791,6 @@ class PyExecutor:
|
||||
# Receive tokens from previous pp rank (w.r.t model forward direction)
|
||||
(
|
||||
logits,
|
||||
sample_state.log_probs,
|
||||
sample_state.host,
|
||||
) = self.dist.recv_object(
|
||||
src=self.dist.prev_pp_rank,
|
||||
@ -796,8 +798,9 @@ class PyExecutor:
|
||||
)
|
||||
if logits is not None:
|
||||
logits_host = torch.from_numpy(logits)
|
||||
sample_state.logits_host = logits_host
|
||||
sample_state.logits = logits_host.to(self.device_id)
|
||||
sample_state.host.logits = logits_host
|
||||
sample_state.device.logits = logits_host.to(
|
||||
self.device_id)
|
||||
else:
|
||||
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
|
||||
sample_state.sampler_event.synchronize()
|
||||
@ -807,16 +810,16 @@ class PyExecutor:
|
||||
if not self.dist.is_second_last_pp_rank:
|
||||
if self.send_handles[prev_microbatch_id] is not None:
|
||||
self.send_handles[prev_microbatch_id].Wait()
|
||||
needs_logits = (
|
||||
self._need_return_logits(scheduled_batch)
|
||||
or (self._need_return_log_probs(scheduled_batch)
|
||||
and sample_state.host.log_probs is not None))
|
||||
serialized_logits = sample_state.host.logits.numpy(
|
||||
) if needs_logits else None
|
||||
self.send_handles[
|
||||
prev_microbatch_id] = self.dist.isend_object(
|
||||
(
|
||||
sample_state.logits_host.numpy() if
|
||||
self._need_return_logits(scheduled_batch) or
|
||||
(self._need_return_log_probs(
|
||||
scheduled_batch)
|
||||
and sample_state.log_probs is not None)
|
||||
else None,
|
||||
sample_state.log_probs,
|
||||
serialized_logits,
|
||||
sample_state.host,
|
||||
),
|
||||
dest=self.dist.next_pp_rank,
|
||||
@ -1754,41 +1757,35 @@ class PyExecutor:
|
||||
input_tokens = spec_config.get_draft_model_prompt(
|
||||
request.get_tokens()[beam_idx])
|
||||
|
||||
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
|
||||
# This is the first time the draft model is seeing this request.
|
||||
# Prepare a context request. We discard the first token and take
|
||||
# the newly decoded one - this is the convention for EAGLE 2 and 3.
|
||||
assert num_draft_tokens == 0
|
||||
new_request = LlmRequest(
|
||||
def create_new_request(input_tokens):
|
||||
return LlmRequest(
|
||||
request_id=request.py_request_id,
|
||||
max_new_tokens=request.py_max_new_tokens,
|
||||
input_tokens=input_tokens,
|
||||
sampling_config=request.sampling_config,
|
||||
return_perf_metrics=request.return_perf_metrics,
|
||||
is_streaming=False)
|
||||
is_streaming=False,
|
||||
is_draft=True)
|
||||
|
||||
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
|
||||
# This is the first time the draft model is seeing this request.
|
||||
# Prepare a context request. We discard the first token and take
|
||||
# the newly decoded one - this is the convention for EAGLE 2 and 3.
|
||||
assert num_draft_tokens == 0
|
||||
new_request = create_new_request(input_tokens)
|
||||
draft_batch.context_requests.append(new_request)
|
||||
elif num_accepted_tokens == 0:
|
||||
new_request = LlmRequest(
|
||||
request_id=request.py_request_id,
|
||||
max_new_tokens=request.py_max_new_tokens,
|
||||
input_tokens=input_tokens[:-1],
|
||||
sampling_config=request.sampling_config,
|
||||
return_perf_metrics=request.return_perf_metrics,
|
||||
is_streaming=False)
|
||||
new_request = create_new_request(input_tokens[:-1])
|
||||
# Explicitly add the last token so get_last_tokens() returns
|
||||
# the right value
|
||||
new_request.add_new_token(input_tokens[-1], beam_idx)
|
||||
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
|
||||
draft_batch.generation_requests.append(new_request)
|
||||
else:
|
||||
new_request = LlmRequest(
|
||||
request_id=request.py_request_id,
|
||||
max_new_tokens=request.py_max_new_tokens,
|
||||
input_tokens=input_tokens,
|
||||
sampling_config=request.sampling_config,
|
||||
return_perf_metrics=request.return_perf_metrics,
|
||||
is_streaming=False)
|
||||
new_request = create_new_request(input_tokens)
|
||||
new_request.context_chunk_size = num_accepted_tokens + 1
|
||||
new_request.context_current_position = len(
|
||||
input_tokens) - num_accepted_tokens - 1
|
||||
new_request.context_chunk_size = num_accepted_tokens + 1
|
||||
new_request.context_current_position = len(
|
||||
input_tokens) - num_accepted_tokens - 1
|
||||
@ -1807,16 +1804,19 @@ class PyExecutor:
|
||||
|
||||
@nvtx_range("_prepare_draft_tokens")
|
||||
def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
|
||||
if not self.draft_model_engine:
|
||||
raise ValueError("Draft model engine is not set")
|
||||
|
||||
try:
|
||||
draft_batch = self._prepare_draft_batch(scheduled_requests)
|
||||
|
||||
if draft_batch.batch_size == 0:
|
||||
return
|
||||
self.draft_seq_slot_manager.prepare_resources(draft_batch)
|
||||
|
||||
req_id_to_old_request = {
|
||||
req.py_request_id: req
|
||||
for req in chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests)
|
||||
for req in scheduled_requests.all_requests()
|
||||
}
|
||||
|
||||
# Disable cuda graph for the 1st draft model forward
|
||||
@ -1838,8 +1838,7 @@ class PyExecutor:
|
||||
|
||||
def _process_decoded_tokens(draft_batch):
|
||||
new_requests = []
|
||||
for req in chain(draft_batch.context_requests,
|
||||
draft_batch.generation_requests):
|
||||
for req in draft_batch.all_requests():
|
||||
target_model_req = req_id_to_old_request[req.py_request_id]
|
||||
target_model_req.py_draft_tokens.append(
|
||||
req.get_last_tokens(0))
|
||||
@ -1847,6 +1846,8 @@ class PyExecutor:
|
||||
target_model_req.py_draft_tokens
|
||||
) < target_model_req.py_draft_pages_allocated:
|
||||
new_requests.append(req)
|
||||
else:
|
||||
self.draft_seq_slot_manager.free_resources(req)
|
||||
|
||||
return new_requests
|
||||
|
||||
@ -2087,14 +2088,12 @@ class PyExecutor:
|
||||
|
||||
def _add_inflight_ids(self, scheduled_requests):
|
||||
"""Add reqids of current requests to self.inflight_req_ids."""
|
||||
for req in chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests):
|
||||
for req in scheduled_requests.all_requests():
|
||||
self.inflight_req_ids.insert(req.request_id)
|
||||
|
||||
def _remove_inflight_ids(self, scheduled_requests):
|
||||
"""Remove reqids of current requests from self.inflight_req_ids."""
|
||||
for req in chain(scheduled_requests.context_requests,
|
||||
scheduled_requests.generation_requests):
|
||||
for req in scheduled_requests.all_requests():
|
||||
self.inflight_req_ids.erase(req.request_id)
|
||||
|
||||
def _should_exclude_last_generation_logits(self) -> bool:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
@ -27,9 +28,11 @@ from .llm_request import LlmRequest, LlmRequestState
|
||||
from .scheduler import ScheduledRequests
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTensors:
|
||||
new_tokens: torch.Tensor
|
||||
logits: torch.Tensor | None = None
|
||||
log_probs: torch.Tensor | None = None
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
@ -39,13 +42,6 @@ class SampleStateTensors:
|
||||
class SampleState:
|
||||
scheduled_requests: ScheduledRequests
|
||||
|
||||
logits: torch.Tensor = None
|
||||
logits_host: torch.Tensor = None
|
||||
|
||||
# Set when decode_async() has evaluated these to avoid computing again in update_requests()
|
||||
# log_probs[request_idx][token_idx]
|
||||
log_probs: list[list[float] | None] | None = None
|
||||
|
||||
device: SampleStateTensors = None
|
||||
host: SampleStateTensors = None
|
||||
|
||||
@ -77,10 +73,12 @@ class EarlyStopSampler(Sampler):
|
||||
|
||||
def sample_async(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs) -> SampleState:
|
||||
return SampleState(scheduled_requests=scheduled_requests,
|
||||
logits=model_outputs['logits'])
|
||||
host = SampleStateTensors(logits=model_outputs['logits'],
|
||||
new_tokens=torch.empty(0))
|
||||
return SampleState(scheduled_requests=scheduled_requests, host=host)
|
||||
|
||||
def update_requests(self, state: SampleState) -> None:
|
||||
assert isinstance(state, SampleState)
|
||||
scheduled_requests = state.scheduled_requests
|
||||
assert (not scheduled_requests.generation_requests)
|
||||
for idx, request in enumerate(scheduled_requests.context_requests):
|
||||
@ -88,7 +86,7 @@ class EarlyStopSampler(Sampler):
|
||||
# NOTE: This is a hack: set finish reason manually and set the beam 0
|
||||
request.set_finished_reason(FinishReason.LENGTH, 0)
|
||||
if request.py_return_context_logits:
|
||||
logits = state.logits[idx]
|
||||
logits = state.host.logits[idx]
|
||||
if logits.ndim == 1:
|
||||
# For BERT: Add axis to be compatible with LogitsStorage
|
||||
# (LogitsStorage will interpret this dim as the prompt_len which
|
||||
@ -104,8 +102,6 @@ def top_k_sampling_batch(logits, top_k=50):
|
||||
# logits should be 2D :[batch_size, vocab_size]
|
||||
batch_size, vocab_size = logits.size()
|
||||
|
||||
raw_probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
# get first top_k logits of each sample and their indices
|
||||
values, indices = torch.topk(logits, top_k, dim=-1)
|
||||
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
|
||||
@ -115,24 +111,18 @@ def top_k_sampling_batch(logits, top_k=50):
|
||||
torch.full_like(logits, float('-inf')), logits)
|
||||
|
||||
# compute probability distribution
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
softmax = torch.softmax(logits, dim=-1)
|
||||
|
||||
# sample from the distribution and generate result of [batch_size, 1]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||
token_probs = torch.gather(raw_probs, dim=1,
|
||||
index=next_tokens.unsqueeze(1)).squeeze(-1)
|
||||
log_probs = torch.log(token_probs)
|
||||
return next_tokens, log_probs
|
||||
next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
|
||||
return next_tokens, softmax
|
||||
|
||||
|
||||
def top_p_sampling_batch(logits, top_p=0.9):
|
||||
def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
|
||||
logits_dim = logits.dim()
|
||||
if logits_dim == 1:
|
||||
logits = logits.unsqueeze(0)
|
||||
# logits should be 2D :[batch_size, vocab_size]
|
||||
batch_size, vocab_size = logits.size()
|
||||
|
||||
raw_probs = torch.softmax(logits, dim=-1)
|
||||
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
|
||||
|
||||
# sort the logits of each sample in descending order
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
@ -152,46 +142,96 @@ def top_p_sampling_batch(logits, top_p=0.9):
|
||||
logits = logits.masked_fill(indices_to_remove, float('-inf'))
|
||||
|
||||
# compute probability distribution
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
softmax = torch.softmax(logits, dim=-1)
|
||||
|
||||
# sample from the distribution and generate result of [batch_size, 1]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||
token_probs = torch.gather(raw_probs, dim=1,
|
||||
index=next_tokens.unsqueeze(1)).squeeze(-1)
|
||||
log_probs = torch.log(token_probs)
|
||||
return next_tokens, log_probs
|
||||
next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
|
||||
return next_tokens, softmax
|
||||
|
||||
|
||||
def greedy_search_sampling_batch(logits):
|
||||
raw_probs = torch.softmax(logits, dim=-1)
|
||||
next_tokens = torch.argmax(logits, dim=-1)
|
||||
token_probs = torch.gather(raw_probs, dim=1,
|
||||
index=next_tokens.unsqueeze(1)).squeeze(-1)
|
||||
log_probs = torch.log(token_probs)
|
||||
return next_tokens, log_probs
|
||||
softmax = torch.softmax(logits, dim=-1)
|
||||
return next_tokens, softmax
|
||||
|
||||
|
||||
def decode_single_request(request: LlmRequest, logits):
|
||||
assert logits.dim(
|
||||
) == 2 and logits.shape[0] == 1, "logits should have shape [1, vocab_size]"
|
||||
TopK = tuple[Literal["top_k"], int]
|
||||
TopP = tuple[Literal["top_p"], float]
|
||||
Greedy = tuple[Literal["greedy"], None]
|
||||
GREEDY: Greedy = ("greedy", None)
|
||||
Strategy = TopK | TopP | Greedy
|
||||
|
||||
|
||||
def request_strategy(request: LlmRequest) -> Strategy:
|
||||
if request.sampling_config.top_p is not None and len(
|
||||
request.sampling_config.top_p) > 0:
|
||||
next_tokens, log_probs = top_p_sampling_batch(
|
||||
logits, request.sampling_config.top_p[0])
|
||||
return ("top_p", request.sampling_config.top_p[0])
|
||||
elif request.sampling_config.top_k is not None and len(
|
||||
request.sampling_config.top_k) > 0:
|
||||
next_tokens, log_probs = top_k_sampling_batch(
|
||||
logits, request.sampling_config.top_k[0])
|
||||
return ("top_k", request.sampling_config.top_k[0])
|
||||
else:
|
||||
next_tokens, log_probs = greedy_search_sampling_batch(logits)
|
||||
return next_tokens, log_probs
|
||||
return ("greedy", None)
|
||||
|
||||
|
||||
def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
|
||||
return [request_strategy(req) for req in requests]
|
||||
|
||||
|
||||
def sample(strategy: Strategy, logits: torch.Tensor):
|
||||
match strategy:
|
||||
case ("top_k", top_k):
|
||||
return top_k_sampling_batch(logits, top_k)
|
||||
case ("top_p", top_p):
|
||||
return top_p_sampling_batch(logits, top_p)
|
||||
case ("greedy", None):
|
||||
return greedy_search_sampling_batch(logits)
|
||||
|
||||
|
||||
def add_token(request: LlmRequest,
|
||||
new_tokens: torch.Tensor,
|
||||
*,
|
||||
beam: int,
|
||||
step: int = 0) -> int:
|
||||
seq_slot = request.seq_slot
|
||||
assert seq_slot is not None
|
||||
new_token = int(new_tokens[step, request.seq_slot, beam])
|
||||
request.add_new_token(new_token, beam)
|
||||
return new_token
|
||||
|
||||
|
||||
class TorchSampler(Sampler):
|
||||
BEAM = 0
|
||||
MAX_BEAM_WIDTH = BEAM + 1
|
||||
|
||||
def __init__(self, max_seq_len: int, mixed_sampler: bool = False):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.mixed_sampler = mixed_sampler
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Store:
|
||||
new_tokens: torch.Tensor
|
||||
"""Shape: See cpp DecoderState.getAllNewTokens()"""
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Args:
|
||||
max_seq_len: int
|
||||
max_draft_tokens: int
|
||||
max_num_sequences: int
|
||||
max_beam_width: int
|
||||
mixed_sampler: bool
|
||||
|
||||
def __init__(self, args: Args):
|
||||
self.max_seq_len = args.max_seq_len
|
||||
self.mixed_sampler = args.mixed_sampler
|
||||
self.max_tokens = args.max_draft_tokens + 1
|
||||
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
|
||||
self.num_seq_slots = args.max_num_sequences
|
||||
|
||||
# AutoDeploy build creates the sampler in inference mode,
|
||||
# which would disallow in-place mutating of new_tokens.
|
||||
# So, we temporarily exit inference mode.
|
||||
with torch.inference_mode(False):
|
||||
new_tokens = torch.zeros(
|
||||
(self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH),
|
||||
dtype=torch.int,
|
||||
device='cuda')
|
||||
self.store = self.Store(new_tokens=new_tokens)
|
||||
|
||||
def _meet_max_token_stop_criteria(self, request: LlmRequest,
|
||||
num_tokens: int):
|
||||
@ -199,7 +239,8 @@ class TorchSampler(Sampler):
|
||||
>= request.py_max_new_tokens) or (num_tokens
|
||||
>= self.max_seq_len)
|
||||
|
||||
def _meet_stop_token_criteria(self, request: LlmRequest):
|
||||
@staticmethod
|
||||
def _meet_stop_token_criteria(request: LlmRequest):
|
||||
if request.py_stop_words_list:
|
||||
assert isinstance(
|
||||
request.py_stop_words_list,
|
||||
@ -217,233 +258,197 @@ class TorchSampler(Sampler):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_stop_criteria(self, request: LlmRequest, new_token: int,
|
||||
num_tokens: int, beam_idx: int) -> bool:
|
||||
def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *,
|
||||
beam: int) -> bool:
|
||||
"""Handle stop criteria and set appropriate finish reasons and state.
|
||||
Returns True if generation should stop."""
|
||||
if new_token == request.py_end_id:
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
request.set_finished_reason(FinishReason.END_ID, beam_idx)
|
||||
request.finish_by_reason(FinishReason.END_ID)
|
||||
return True
|
||||
|
||||
num_tokens = request.get_num_tokens(beam)
|
||||
if self._meet_max_token_stop_criteria(request, num_tokens):
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
request.set_finished_reason(FinishReason.LENGTH, beam_idx)
|
||||
request.finish_by_reason(FinishReason.LENGTH)
|
||||
return True
|
||||
|
||||
if self._meet_stop_token_criteria(request):
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
request.set_finished_reason(FinishReason.STOP_WORDS, beam_idx)
|
||||
request.finish_by_reason(FinishReason.STOP_WORDS)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def update_requests(self, state: SampleState) -> None:
|
||||
if state.sampler_event:
|
||||
state.sampler_event.synchronize()
|
||||
new_tokens_list = state.host.new_tokens.tolist()
|
||||
scheduled_requests = state.scheduled_requests
|
||||
|
||||
request_idx = 0
|
||||
token_idx = 0
|
||||
beam_idx = 0
|
||||
|
||||
def advance_idx(num_tokens=1):
|
||||
nonlocal request_idx, token_idx
|
||||
request_idx += 1
|
||||
token_idx += num_tokens
|
||||
|
||||
def handle_logits(request: LlmRequest, tokens: list[int], count=1):
|
||||
if state.logits is None:
|
||||
return
|
||||
if not request.py_return_generation_logits and not request.py_return_log_probs:
|
||||
return
|
||||
|
||||
current_slice = slice(token_idx, token_idx + count)
|
||||
current_logits = state.logits[current_slice]
|
||||
|
||||
request.py_result.append_generation_logits(current_logits)
|
||||
|
||||
if not request.py_return_log_probs:
|
||||
return
|
||||
|
||||
if state.log_probs:
|
||||
log_probs = state.log_probs[request_idx]
|
||||
else:
|
||||
_, log_probs = greedy_search_sampling_batch(current_logits)
|
||||
|
||||
token_log_probs = [{
|
||||
token: Logprob(logprob=logprob, rank=1)
|
||||
} for token, logprob in zip(tokens, log_probs.tolist())]
|
||||
request.py_result.append_log_probs([token_log_probs])
|
||||
|
||||
if hasattr(scheduled_requests, 'chunked_requests'):
|
||||
request_idx += len(scheduled_requests.chunked_requests)
|
||||
token_idx += len(scheduled_requests.chunked_requests)
|
||||
|
||||
for request in scheduled_requests.context_requests:
|
||||
if request.context_remaining_length != 0:
|
||||
advance_idx()
|
||||
continue
|
||||
|
||||
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
||||
new_token = new_tokens_list[token_idx]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
self._handle_stop_criteria(request, new_token, num_tokens,
|
||||
beam_idx)
|
||||
handle_logits(request, [new_token])
|
||||
request.py_decoding_iter += 1
|
||||
advance_idx()
|
||||
|
||||
extend_requests = []
|
||||
generation_requests = []
|
||||
for request in scheduled_requests.generation_requests:
|
||||
if len(request.py_draft_tokens) > 0:
|
||||
extend_requests.append(request)
|
||||
else:
|
||||
generation_requests.append(request)
|
||||
|
||||
for request in extend_requests:
|
||||
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
||||
new_token = new_tokens_list[token_idx]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
if self._handle_stop_criteria(request, new_token, num_tokens,
|
||||
beam_idx):
|
||||
continue
|
||||
|
||||
# Accept draft tokens (if we have any) if and only if they match the new
|
||||
# token exactly.
|
||||
num_accepted = 0
|
||||
new_tokens = [new_token]
|
||||
for draft_token in request.py_draft_tokens:
|
||||
if draft_token != new_token:
|
||||
# Reject.
|
||||
break
|
||||
num_accepted += 1
|
||||
new_token = new_tokens_list[token_idx + num_accepted]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
new_tokens.append(num_tokens) # `num_tokens`->`new_token`
|
||||
|
||||
if self._handle_stop_criteria(request, new_token,
|
||||
num_tokens, beam_idx):
|
||||
break
|
||||
handle_logits(request, new_tokens, num_accepted)
|
||||
request.py_decoding_iter += 1
|
||||
request.py_num_accepted_draft_tokens = num_accepted
|
||||
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
|
||||
advance_idx(len(request.py_draft_tokens) + 1)
|
||||
|
||||
for request in generation_requests:
|
||||
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
||||
new_token = new_tokens_list[token_idx]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
self._handle_stop_criteria(request, new_token, num_tokens,
|
||||
beam_idx)
|
||||
handle_logits(request, [new_token])
|
||||
request.py_decoding_iter += 1
|
||||
advance_idx()
|
||||
|
||||
def _mixed_sample(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs) -> SampleState:
|
||||
logits = model_outputs["logits"]
|
||||
log_probs = []
|
||||
new_tokens_device_array = []
|
||||
|
||||
idx = 0
|
||||
|
||||
for request in scheduled_requests.context_requests:
|
||||
assert not request.py_return_context_logits, "Return context logits not supported"
|
||||
token_logits = logits[idx:idx + 1, :]
|
||||
new_token, probs = decode_single_request(request, token_logits)
|
||||
new_tokens_device_array.append(new_token)
|
||||
probs = [probs.tolist()] if request.py_return_log_probs else None
|
||||
log_probs.append(probs) # Currently always beam_width=1
|
||||
idx += 1
|
||||
|
||||
for request in scheduled_requests.generation_requests:
|
||||
if request.state == LlmRequestState.GENERATION_COMPLETE:
|
||||
continue
|
||||
assert len(
|
||||
request.py_draft_tokens
|
||||
) == 0, "Speculative decoding not supported in SeparateDecoder."
|
||||
token_logits = logits[idx:idx + 1, :]
|
||||
new_token, probs = decode_single_request(request, token_logits)
|
||||
new_tokens_device_array.append(new_token)
|
||||
probs = [probs.tolist()] if request.py_return_log_probs else None
|
||||
log_probs.append(probs) # Currently always beam_width=1
|
||||
idx += 1
|
||||
|
||||
new_tokens_device = torch.cat(new_tokens_device_array)
|
||||
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
|
||||
sampler_event = torch.cuda.Event()
|
||||
sampler_event.record()
|
||||
|
||||
return SampleState(
|
||||
scheduled_requests=scheduled_requests,
|
||||
logits=logits,
|
||||
device=SampleStateTensors(new_tokens=new_tokens_device),
|
||||
host=SampleStateTensors(new_tokens=new_tokens_host),
|
||||
sampler_event=sampler_event,
|
||||
log_probs=log_probs)
|
||||
|
||||
def _batch_sample(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs) -> SampleState:
|
||||
logits = model_outputs["logits"]
|
||||
new_tokens_device = torch.argmax(logits, dim=-1)
|
||||
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
|
||||
sampler_event = torch.cuda.Event()
|
||||
sampler_event.record()
|
||||
return SampleState(
|
||||
scheduled_requests=scheduled_requests,
|
||||
logits=logits,
|
||||
device=SampleStateTensors(new_tokens=new_tokens_device),
|
||||
host=SampleStateTensors(new_tokens=new_tokens_host),
|
||||
sampler_event=sampler_event)
|
||||
|
||||
def sample_async(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs) -> SampleState:
|
||||
if self.mixed_sampler:
|
||||
return self._mixed_sample(scheduled_requests, model_outputs)
|
||||
else:
|
||||
return self._batch_sample(scheduled_requests, model_outputs)
|
||||
|
||||
|
||||
class TorchStarAttentionSampler(TorchSampler):
|
||||
|
||||
def update_one_request(self, request: LlmRequest,
|
||||
new_tokens_list: list[int], logits: torch.Tensor):
|
||||
beam_idx = 0
|
||||
|
||||
output_token_idx = request.output_token_idx
|
||||
new_token = new_tokens_list[output_token_idx]
|
||||
num_tokens = request.add_new_token(new_token, beam_idx)
|
||||
|
||||
current_logits = logits[output_token_idx].unsqueeze(0)
|
||||
def handle_logits(self, request: LlmRequest, state: SampleState, *,
|
||||
beam: int, count: int):
|
||||
current_slice = slice(0, count), request.seq_slot, beam
|
||||
if request.py_return_generation_logits:
|
||||
assert state.host.logits is not None
|
||||
current_logits = state.host.logits[current_slice]
|
||||
request.py_result.append_generation_logits(current_logits)
|
||||
if request.py_return_log_probs:
|
||||
_, log_probs = greedy_search_sampling_batch(current_logits)
|
||||
request.py_result.append_log_probs([[{
|
||||
new_token:
|
||||
Logprob(logprob=log_probs.item(), rank=1)
|
||||
}]])
|
||||
assert state.host.log_probs is not None
|
||||
log_probs = state.host.log_probs[request.seq_slot][beam][:count]
|
||||
current_tokens = state.host.new_tokens[current_slice]
|
||||
|
||||
self._handle_stop_criteria(request, new_token, num_tokens, beam_idx)
|
||||
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
||||
request.py_decoding_iter += 1
|
||||
token_log_probs = [{
|
||||
int(token): Logprob(logprob=logprob, rank=1)
|
||||
} for token, logprob in zip(current_tokens, log_probs.tolist())]
|
||||
assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
|
||||
request.py_result.append_log_probs([token_log_probs])
|
||||
|
||||
def update_requests(self, state: SampleState):
|
||||
def process_draft_tokens(self, request: LlmRequest,
|
||||
new_tokens: torch.Tensor, new_token: int) -> int:
|
||||
num_accepted = 0
|
||||
for draft_token in request.py_draft_tokens:
|
||||
if draft_token != new_token:
|
||||
# Reject.
|
||||
break
|
||||
num_accepted += 1
|
||||
new_token = add_token(request,
|
||||
new_tokens,
|
||||
beam=self.BEAM,
|
||||
step=num_accepted)
|
||||
if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
|
||||
break
|
||||
return num_accepted
|
||||
|
||||
def update_requests(self, state: SampleState) -> None:
|
||||
assert isinstance(state, SampleState)
|
||||
if state.sampler_event:
|
||||
state.sampler_event.synchronize()
|
||||
new_tokens_list = state.host.new_tokens.tolist()
|
||||
logits = state.logits
|
||||
new_tokens = state.host.new_tokens
|
||||
|
||||
for request in state.scheduled_requests.context_requests:
|
||||
if request.state == LlmRequestState.GENERATION_IN_PROGRESS:
|
||||
self.update_one_request(request, new_tokens_list, logits)
|
||||
for req in state.scheduled_requests.context_requests:
|
||||
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
|
||||
continue
|
||||
new_token = add_token(req, new_tokens, beam=self.BEAM)
|
||||
stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
|
||||
self.handle_logits(req, state, beam=self.BEAM, count=1)
|
||||
req.py_decoding_iter += 1
|
||||
|
||||
for request in state.scheduled_requests.generation_requests:
|
||||
self.update_one_request(request, new_tokens_list, logits)
|
||||
for req in state.scheduled_requests.generation_requests:
|
||||
if req.state == LlmRequestState.GENERATION_COMPLETE:
|
||||
continue
|
||||
new_token = add_token(req, new_tokens, beam=self.BEAM)
|
||||
stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
|
||||
processed = 1
|
||||
if not stop and len(req.py_draft_tokens) > 0:
|
||||
num_accepted = self.process_draft_tokens(
|
||||
req, new_tokens, new_token)
|
||||
req.py_num_accepted_draft_tokens = num_accepted
|
||||
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
|
||||
processed += num_accepted
|
||||
self.handle_logits(req, state, beam=self.BEAM, count=processed)
|
||||
req.py_decoding_iter += 1
|
||||
|
||||
def log_probs_host(self, requests: Iterable[LlmRequest]):
|
||||
"""Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
|
||||
if any(req.py_return_log_probs for req in requests):
|
||||
return torch.empty(
|
||||
(self.num_seq_slots, self.MAX_BEAM_WIDTH, self.max_tokens),
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
return None
|
||||
|
||||
def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):
|
||||
if any(req.py_return_generation_logits for req in requests):
|
||||
return torch.empty((self.max_tokens, self.num_seq_slots,
|
||||
self.MAX_BEAM_WIDTH, vocab_size),
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
return None
|
||||
|
||||
def sample_async(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs: dict[str, torch.Tensor]) -> SampleState:
|
||||
requests = scheduled_requests.all_requests()
|
||||
new_tokens = self.store.new_tokens
|
||||
vocab_size = model_outputs["logits"].shape[-1]
|
||||
log_probs_host = self.log_probs_host(requests)
|
||||
gen_logits_host = self.gen_logits_host(requests, vocab_size)
|
||||
self._process_requests(requests,
|
||||
model_outputs,
|
||||
new_tokens,
|
||||
gen_logits_host=gen_logits_host,
|
||||
log_probs_host=log_probs_host)
|
||||
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
|
||||
sampler_event = torch.cuda.Event()
|
||||
sampler_event.record()
|
||||
return SampleState(scheduled_requests=scheduled_requests,
|
||||
device=SampleStateTensors(new_tokens=new_tokens),
|
||||
host=SampleStateTensors(new_tokens=new_tokens_host,
|
||||
log_probs=log_probs_host,
|
||||
logits=gen_logits_host),
|
||||
sampler_event=sampler_event)
|
||||
|
||||
@staticmethod
|
||||
def append_eagle3(tokens: torch.Tensor, model_outputs):
|
||||
if "d2t" in model_outputs:
|
||||
d2t = model_outputs["d2t"][tokens]
|
||||
tokens += d2t
|
||||
|
||||
def _process_requests(self,
|
||||
requests: list[LlmRequest],
|
||||
model_outputs: dict[str, torch.Tensor],
|
||||
new_tokens: torch.Tensor,
|
||||
*,
|
||||
gen_logits_host: torch.Tensor | None = None,
|
||||
log_probs_host: torch.Tensor | None = None):
|
||||
beam_width = self.MAX_BEAM_WIDTH
|
||||
beam = self.BEAM
|
||||
raw_logits = model_outputs["logits"]
|
||||
num_steps = [1 + len(req.py_draft_tokens) for req in requests]
|
||||
sum_steps = sum(num_steps)
|
||||
no_draft_tokens = len(requests) == sum_steps
|
||||
fast_path = not self.mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
|
||||
|
||||
seq_slots = torch.as_tensor([r.seq_slot for r in requests])
|
||||
seq_slots = seq_slots.to(device="cuda", non_blocking=True)
|
||||
|
||||
if fast_path:
|
||||
logits = raw_logits[:len(requests)]
|
||||
next_tokens = torch.argmax(logits, dim=-1)
|
||||
self.append_eagle3(next_tokens, model_outputs)
|
||||
int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
|
||||
next_tokens = int_next_tokens.view(1, -1, beam_width)
|
||||
new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
|
||||
return
|
||||
|
||||
strategies = sampling_strategies(requests)
|
||||
batched_next_tokens, batched_softmax = None, None
|
||||
batched_strategy: Strategy | None = GREEDY
|
||||
if self.mixed_sampler:
|
||||
assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling"
|
||||
if len(set(strategies)) == 1:
|
||||
batched_strategy = strategies[0]
|
||||
else:
|
||||
batched_strategy = None
|
||||
|
||||
if batched_strategy is not None:
|
||||
logits = raw_logits[:sum_steps]
|
||||
batched_next_tokens, batched_softmax = sample(
|
||||
batched_strategy, logits)
|
||||
self.append_eagle3(batched_next_tokens, model_outputs)
|
||||
|
||||
offset = 0
|
||||
for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
|
||||
input_slice = slice(offset, offset + steps)
|
||||
logits = raw_logits[input_slice]
|
||||
if batched_next_tokens is None:
|
||||
next_tokens, softmax = sample(strategy, logits)
|
||||
else:
|
||||
next_tokens = batched_next_tokens[input_slice]
|
||||
softmax = batched_softmax[input_slice]
|
||||
current_slice = slice(0, steps), slot, beam
|
||||
new_tokens[current_slice] = next_tokens
|
||||
if gen_logits_host is not None:
|
||||
gen_logits_host[current_slice].copy_(logits, non_blocking=True)
|
||||
if log_probs_host is not None:
|
||||
assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
|
||||
token_probs = torch.gather(
|
||||
softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1)
|
||||
log_probs = torch.log(token_probs)
|
||||
log_probs_host[slot, beam, :steps].copy_(log_probs,
|
||||
non_blocking=True)
|
||||
offset += steps
|
||||
|
||||
|
||||
class Algorithms:
|
||||
@ -456,19 +461,17 @@ class Algorithms:
|
||||
return f"Algs({', '.join(algs)})"
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTensorsHostTRTLLM(SampleStateTensors):
|
||||
finished_sum: torch.Tensor
|
||||
finish_reasons: torch.Tensor
|
||||
sequence_lengths: torch.Tensor
|
||||
log_probs: torch.Tensor
|
||||
cum_log_probs: torch.Tensor
|
||||
cum_log_probs: torch.Tensor | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTRTLLM(SampleState):
|
||||
host: SampleStateTensorsHostTRTLLM
|
||||
device: SampleStateTensors
|
||||
|
||||
|
||||
class TRTLLMSampler(Sampler):
@ -527,13 +530,6 @@ class TRTLLMSampler(Sampler):
                DecoderInputBuffers(self.max_num_sequences,
                                    self.executor_config.max_batch_size,
                                    self.MAX_DECODING_TOKENS, buffer_manager),
            "new_tokens_device_tensor":
                torch.empty((
                    self.executor_config.max_batch_size,
                    self.executor_config.max_beam_width,
                ),
                            dtype=torch.int,
                            device='cuda'),
            "sequence_lengths_host":
                torch.empty((
                    self.executor_config.max_batch_size,
@ -602,7 +598,6 @@ class TRTLLMSampler(Sampler):
    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs) -> SampleStateTRTLLM:
        batch_size = scheduled_requests.batch_size
        beam_width = self.beam_width(scheduled_requests.all_requests)

        self.setup_sampler_step(scheduled_requests.context_requests)

@ -631,20 +626,6 @@
        self.algs.decoder.forward_async(self.store["decoder_state"],
                                        decoding_input)

        # NOTE: The following code prepares a new_tokens_device_tensor in accordance with the
        # current implementation of model_engine.
        # TODO: When we support speculative decoding:
        # new_tokens_device_tensor should be, for speculative decoding cases: [batch, 1 + draft_len], others: [batch]
        new_tokens_device_tensor = self.store[
            "new_tokens_device_tensor"][:batch_size, :beam_width]
        seq_slots = [
            request.seq_slot for request in scheduled_requests.all_requests
        ]
        new_tokens_device_tensor.copy_(
            self.store["decoder_state"].all_new_tokens[0][seq_slots],
            non_blocking=True)
        new_tokens_device_tensor = new_tokens_device_tensor.view(-1)

        new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
            'cpu', non_blocking=True)
        finished_sum = self.store["decoder_state"].finished_sum.to(
@ -654,16 +635,17 @@
        sequence_lengths = self.store["decoder_state"].sequence_lengths.to(
            'cpu', non_blocking=True)

        log_probs = torch.empty([0], dtype=torch.float, device='cpu')
        cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu')
        log_probs = None
        cum_log_probs = None
        if any(request.py_return_log_probs
               for request in scheduled_requests.all_requests):
               for request in scheduled_requests.all_requests()):
            log_probs = self.store["decoder_state"].log_probs.to(
                'cpu', non_blocking=True)
            cum_log_probs = self.store["decoder_state"].cum_log_probs.to(
                'cpu', non_blocking=True)

        device = SampleStateTensors(new_tokens=new_tokens_device_tensor)
        device = SampleStateTensors(
            new_tokens=self.store["decoder_state"].all_new_tokens)

        host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens,
                                            finished_sum=finished_sum,
@ -676,7 +658,6 @@
        sampler_event.record()

        return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
                                 logits=model_outputs["logits"],
                                 device=device,
                                 host=host,
                                 sampler_event=sampler_event)
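The device-side state above now carries the decoder's full all_new_tokens buffer rather than a batch-compacted copy. A small sketch of how a consumer might read it, assuming the buffer is indexed as [step, seq_slot, beam] (an inference from the indexing in this diff, not a documented contract):

import torch

all_new_tokens = torch.arange(2 * 8 * 1).view(2, 8, 1)  # toy [steps, slots, beams] buffer
seq_slots = torch.tensor([5, 0, 2])                      # slots of the scheduled requests
step0_tokens = all_new_tokens[0, seq_slots, 0]           # latest token for each scheduled request
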
@ -687,7 +668,8 @@

        scheduled_requests = state.scheduled_requests
        assert scheduled_requests.batch_size > 0
        beam_width = self.beam_width(scheduled_requests.all_requests)
        requests = scheduled_requests.all_requests()
        beam_width = self.beam_width(requests)
        sampler_event = state.sampler_event

        if sampler_event:
@ -697,7 +679,7 @@
        finished_sum_host = state.host.finished_sum
        sequence_lengths_host_data = state.host.sequence_lengths

        for request in scheduled_requests.all_requests:
        for request in requests:
            if request.is_context_init_state:
                continue

@ -718,17 +700,20 @@
                    seq_len - request.get_num_tokens(beam))

                for step in range(num_new_tokens[beam]):
                    new_token = new_tokens_host[step][seq_slot][beam]
                    request.add_new_token(new_token, beam)
                    new_token = add_token(request,
                                          new_tokens_host,
                                          beam=beam,
                                          step=step)

                    if request.py_return_log_probs:
                        assert state.host.log_probs is not None
                        # NOTE: Log probs with drafting has not been tested yet.
                        begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0
                        current_token = seq_len - request.prompt_len - num_new_tokens[
                            beam] + step

                        log_probs.append({
                            new_token.item():
                            new_token:
                            Logprob(logprob=state.host.log_probs[seq_slot][beam]
                                    [begin_log_probs_offset +
                                     current_token].item(),

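add_token is defined outside this hunk; judging from the inline code it replaces, a functionally equivalent sketch (hypothetical, not the library's implementation) could look like the following. Note that it returns a plain int, which is why the .item() call disappears in the log-probs bookkeeping above.

def add_token_sketch(request, new_tokens, *, beam: int, step: int = 0) -> int:
    # new_tokens assumed indexed as [step, seq_slot, beam], per the code it replaces
    token = int(new_tokens[step][request.seq_slot][beam])
    request.add_new_token(token, beam)
    return token
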
@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from itertools import chain
from typing import Optional

from tensorrt_llm.bindings import executor as tb_executor
@ -36,9 +35,8 @@ class ScheduledRequests:
    def batch_size(self) -> int:
        return len(self.context_requests) + len(self.generation_requests)

    @property
    def all_requests(self) -> chain[LlmRequest]:
        return chain(self.context_requests, self.generation_requests)
    def all_requests(self) -> list[LlmRequest]:
        return self.context_requests + self.generation_requests

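The property-to-method change above also changes the return type from a single-use chain iterator to a reusable list; a quick self-contained illustration of why that matters for call sites that iterate more than once or take len():

from itertools import chain

ctx, gen = ["ctx_req"], ["gen_req_a", "gen_req_b"]   # toy stand-ins for request lists

old_style = chain(ctx, gen)       # what the old property returned: exhausted after one pass
first_pass = list(old_style)
second_pass = list(old_style)
assert first_pass == ["ctx_req", "gen_req_a", "gen_req_b"] and second_pass == []

new_style = ctx + gen             # what the new method returns: indexable, reusable, len()-able
assert len(new_style) == 3 and list(new_style) == first_pass
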
class RequestScheduler(ABC):

@ -1,5 +1,3 @@
import itertools

from .llm_request import LlmRequest
from .resource_manager import BaseResourceManager, SlotManager
from .scheduler import ScheduledRequests
@ -17,10 +15,8 @@ class SeqSlotManager(BaseResourceManager):
        return 1

    def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None:
        for llm_req in itertools.chain(scheduled_batch.context_requests,
                                       scheduled_batch.generation_requests):
            if (llm_req.is_context_init_state and llm_req.seq_slot is None) or \
                    llm_req.is_disagg_generation_transmission_complete:
        for llm_req in scheduled_batch.all_requests():
            if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete:
                llm_req.seq_slot = self.slot_manager.add_slot(
                    llm_req.request_id)
                if llm_req.return_perf_metrics:

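A toy restatement (not library code) of the simplified slot-assignment condition above: with all_requests() covering both context and generation requests, a request gets a fresh slot when it has no slot yet or when its disaggregated-generation transmission just completed.

def needs_slot_sketch(seq_slot, disagg_transmission_complete: bool) -> bool:
    return seq_slot is None or disagg_transmission_complete

assert needs_slot_sketch(None, False) is True
assert needs_slot_sketch(7, True) is True
assert needs_slot_sketch(7, False) is False
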
@ -10,7 +10,7 @@ from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
from .mtp import MTPSampler
@ -214,26 +214,6 @@ class Eagle3SpecMetadata(SpecMetadata):
        return hidden_states


class Eagle3Sampler(TorchSampler):

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
        logits = model_outputs["logits"]
        new_tokens_device = torch.argmax(logits, dim=-1)
        if "d2t" in model_outputs:
            d2t = model_outputs["d2t"]
            new_tokens_device = d2t[new_tokens_device] + new_tokens_device
        device = SampleStateTensors(new_tokens=new_tokens_device)
        host = SampleStateTensors(
            new_tokens=new_tokens_device.to('cpu', non_blocking=True))
        sampler_event = torch.cuda.Event()
        sampler_event.record()
        return SampleState(scheduled_requests=scheduled_requests,
                           logits=logits,
                           device=device,
                           host=host,
                           sampler_event=sampler_event)

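For context on the class removed above: the d2t tensor maps draft-model token ids into the target vocabulary via per-id offsets, and (per the comment in get_spec_decoder further down) the unified TorchSampler now applies this during sampling. A standalone sketch of the remap with made-up values:

import torch

d2t = torch.tensor([100, 100, 250, 250])     # hypothetical offsets for a draft vocab of 4
draft_logits = torch.randn(2, 4)             # two requests, draft-vocab logits
draft_ids = torch.argmax(draft_logits, dim=-1)
target_ids = d2t[draft_ids] + draft_ids      # same formula as in the removed sampler
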
@dataclass
class Eagle3OneModelSpecMetadata(SpecMetadata):
    # The hidden states
@ -299,31 +279,10 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
                break


class Eagle3Decoder(TorchSampler):
class Eagle3OneModelSampler(MTPSampler):

    def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState:
        logits = model_outputs["logits"]
        new_tokens_device = torch.argmax(logits, dim=-1)
        if "d2t" in model_outputs:
            d2t = model_outputs["d2t"]
            new_tokens_device = d2t[new_tokens_device] + new_tokens_device
        new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
        new_tensors_device = {"new_tokens_device": new_tokens_device}
        new_tensors_host = {"new_tokens_host": new_tokens_host}
        decoder_event = torch.cuda.Event()
        decoder_event.record()
        return SampleState(scheduled_requests=scheduled_requests,
                           logits=logits,
                           new_tensors_device=new_tensors_device,
                           new_tensors_host=new_tensors_host,
                           decoder_event=decoder_event)


class Eagle3OneModelDecoder(MTPSampler):

    def __init__(self, max_seq_len: int, config: Eagle3Config):
        super().__init__(max_seq_len, None)
        self.draft_len = config.max_draft_tokens
    def __init__(self, args: TorchSampler.Args):
        super().__init__(args, nextn=args.max_draft_tokens)

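A toy restatement (stand-in classes, not the library's) of the constructor rewiring above: the one-model Eagle3 sampler no longer takes (max_seq_len, config) but a single TorchSampler.Args-style bundle, and derives its draft length from args.max_draft_tokens.

from dataclasses import dataclass

@dataclass(kw_only=True)
class ArgsSketch:                 # stand-in for TorchSampler.Args
    max_seq_len: int
    max_draft_tokens: int

class MTPSamplerSketch:           # stand-in for MTPSampler
    def __init__(self, args: ArgsSketch, *, nextn: int):
        self.args = args
        self.draft_len = nextn

class Eagle3OneModelSamplerSketch(MTPSamplerSketch):   # stand-in for Eagle3OneModelSampler
    def __init__(self, args: ArgsSketch):
        super().__init__(args, nextn=args.max_draft_tokens)

sampler = Eagle3OneModelSamplerSketch(ArgsSketch(max_seq_len=4096, max_draft_tokens=3))
assert sampler.draft_len == 3
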
class Eagle3OneModelWorker(nn.Module):

@ -14,7 +14,7 @@ from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class SampleStateTensorsMTP(SampleStateTensors):
    new_tokens_lens: torch.Tensor
    next_draft_tokens: torch.Tensor
@ -248,12 +248,10 @@ class MTPSampler(TorchSampler):

    SampleState = SampleStateMTP

    def __init__(self, max_seq_len: int, config: MTPConfig):
        super().__init__(max_seq_len, False)
    def __init__(self, args: TorchSampler.Args, *, nextn: int):
        super().__init__(args)
        self.mapping = None
        self.draft_len = 0
        if config is not None:
            self.draft_len = config.num_nextn_predict_layers
        self.draft_len = nextn

    def _draft_meet_max_token_stop_criteria(self, request: LlmRequest,
                                            num_tokens: int, beam_idx: int):
@ -283,8 +281,9 @@ class MTPSampler(TorchSampler):
            if request.state != LlmRequestState.GENERATION_COMPLETE:
                new_token = new_tokens_list[idx][0]
                num_tokens = request.add_new_token(new_token, beam_idx)
                should_stop = self._handle_stop_criteria(
                    request, new_token, num_tokens, beam_idx)
                should_stop = self._handle_stop_criteria(request,
                                                         new_token,
                                                         beam=beam_idx)
                if self._draft_meet_max_token_stop_criteria(
                        request, num_tokens, beam_idx):
                    should_stop = True
@ -303,8 +302,9 @@ class MTPSampler(TorchSampler):
                for i in range(num_new_tokens):
                    new_token = new_tokens[i]
                    num_tokens = request.add_new_token(new_token, beam_idx)
                    should_stop = self._handle_stop_criteria(
                        request, new_token, num_tokens, beam_idx)
                    should_stop = self._handle_stop_criteria(request,
                                                             new_token,
                                                             beam=beam_idx)
                    if should_stop:
                        break
                    if self._draft_meet_max_token_stop_criteria(
@ -344,7 +344,6 @@ class MTPSampler(TorchSampler):
        for request in scheduled_requests.context_requests:
            request.py_draft_tokens = [1] * self.draft_len
        return SampleStateMTP(scheduled_requests=scheduled_requests,
                              logits=model_outputs['logits'],
                              device=device,
                              host=host,
                              sampler_event=sampler_event)

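A toy restatement (not library code) of the accept-and-stop loop in the MTP update path above: accepted tokens are appended one at a time, and acceptance stops as soon as a stop criterion or the draft/max-token guard fires.

def accept_tokens_sketch(accepted_tokens, is_stop_token, max_total: int) -> list:
    emitted = []
    for token in accepted_tokens:
        emitted.append(token)
        if is_stop_token(token) or len(emitted) >= max_total:
            break
    return emitted

assert accept_tokens_sketch([11, 42, 7], lambda t: t == 42, max_total=8) == [11, 42]
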
@ -1,6 +1,9 @@
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.interface import SpecConfig

from .draft_target import DraftTargetSpecMetadata
from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata,
                     Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler,
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
                     Eagle3OneModelWorker, Eagle3ResourceManager,
                     Eagle3SpecMetadata)
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                  MTPSpecMetadata, MTPWorker)
@ -97,14 +100,17 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
    return None


def get_spec_decoder(max_seq_len, spec_config):
def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
    if spec_config.spec_dec_mode.is_mtp():
        return MTPSampler(max_seq_len, spec_config)
        return MTPSampler(sampler_args,
                          nextn=spec_config.num_nextn_predict_layers)
    if spec_config.spec_dec_mode.is_eagle3():
        return Eagle3Sampler(max_seq_len)
        # TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process
        return TorchSampler(sampler_args)
    if spec_config.spec_dec_mode.is_eagle3_one_model():
        return Eagle3OneModelDecoder(max_seq_len, spec_config)
    return None
        return Eagle3OneModelSampler(sampler_args)
    raise ValueError(
        f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")

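A self-contained restatement (toy names) of the dispatch above, mainly to highlight the behavioral change at the end: an unsupported speculative-decoding mode now raises a ValueError instead of silently returning None.

from enum import Enum, auto

class SpecModeSketch(Enum):       # stand-in for spec_config.spec_dec_mode
    MTP = auto()
    EAGLE3 = auto()
    EAGLE3_ONE_MODEL = auto()
    OTHER = auto()

def get_spec_decoder_sketch(mode: SpecModeSketch) -> str:
    if mode is SpecModeSketch.MTP:
        return "MTPSampler"
    if mode is SpecModeSketch.EAGLE3:
        return "TorchSampler"     # d2t handled inside the unified sampler
    if mode is SpecModeSketch.EAGLE3_ONE_MODEL:
        return "Eagle3OneModelSampler"
    raise ValueError(f"Unsupported speculative decoding mode: {mode}")

assert get_spec_decoder_sketch(SpecModeSketch.EAGLE3) == "TorchSampler"
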
def get_spec_drafter(model_engine, spec_resource_manager=None):