Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10030][chore] promote SampleState to TypeVar + typing fixes (#11281)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
parent eae480b713
commit 7d235cfb23
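Note: the heart of this change is making SampleState generic over the host/device tensor containers it carries, and making Sampler generic over the SampleState subtype it produces, so each concrete sampler declares its state type once instead of overriding fields and suppressing the type checker at use sites. Below is a minimal, standalone sketch of that pattern; every name in it is an illustrative stand-in, not the actual TensorRT-LLM definition.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar


@dataclass(kw_only=True)
class Tensors:
    # Stand-in for SampleStateTensors.
    new_tokens: list[int]


HostT = TypeVar("HostT", bound=Tensors)
DeviceT = TypeVar("DeviceT", bound=Tensors)


@dataclass(kw_only=True)
class State(Generic[HostT, DeviceT]):
    # Stand-in for SampleState: parametrized by its host/device tensor types.
    host: Optional[HostT] = None
    device: Optional[DeviceT] = None


StateT = TypeVar("StateT", bound=State)


class BaseSampler(ABC, Generic[StateT]):
    # Stand-in for Sampler: each subclass states which State subtype it produces.

    @abstractmethod
    def sample_async(self) -> StateT:
        raise NotImplementedError

    @abstractmethod
    def update_requests(self, state: StateT) -> None:
        raise NotImplementedError


@dataclass(kw_only=True)
class HostTensors(Tensors):
    # Stand-in for a host-side specialization such as SampleStateTensorsHostTorch.
    log_probs: Optional[list[float]] = None


class ConcreteSampler(BaseSampler[State[HostTensors, Tensors]]):
    def sample_async(self) -> State[HostTensors, Tensors]:
        return State(host=HostTensors(new_tokens=[1, 2, 3]))

    def update_requests(self, state: State[HostTensors, Tensors]) -> None:
        # The checker knows state.host is Optional[HostTensors]; an assert narrows it,
        # replacing the cast()/"# type: ignore" workarounds removed by this commit.
        assert state.host is not None
        print(state.host.new_tokens, state.host.log_probs)

The concrete samplers in the diff below (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler, MTPSampler) follow this shape: each binds the type parameters once in its class declaration and narrows Optional fields with asserts.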
@@ -1,10 +1,14 @@
 from copy import copy, deepcopy
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 
 import tensorrt_llm.bindings
+
+if TYPE_CHECKING:
+    from tensorrt_llm._torch.pyexecutor.sampler import Strategy
+
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
@@ -583,6 +587,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
             additional_outputs=additional_outputs)
         self.child_requests = []
 
+        self._py_sampling_strategy: "Strategy | None" = None
+
         self._py_embedding_bias_1d: Optional[torch.Tensor] = None
         if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
             # Pre-squeeze to 1D if needed (remove batch dimension)
@@ -35,6 +35,7 @@ from tensorrt_llm.bindings import (
     CudaStream,
     DataType,
     ModelConfig,
+    SamplingConfigVector,
     WorldConfig,
     make_sampling_config,
 )
@@ -99,8 +100,8 @@ class LogProbsState:
 
 @dataclass(kw_only=True)
 class LogProbsStateList:
-    FloatState = list[list[list[float]]]
-    IntState = list[list[list[int]]]
+    FloatState: TypeAlias = list[list[list[float]]]
+    IntState: TypeAlias = list[list[list[int]]]
 
     sampled_vals: FloatState
     sampled_indices: IntState
@@ -139,19 +140,26 @@ class SamplerEvent:
         self.cuda_event.synchronize()
 
 
+GenericSampleStateTensorsHost = TypeVar("GenericSampleStateTensorsHost", bound=SampleStateTensors)
+GenericSampleStateTensorsDevice = TypeVar(
+    "GenericSampleStateTensorsDevice", bound=SampleStateTensors
+)
+
+
 @dataclass(kw_only=True)
-class SampleState:
+class SampleState(Generic[GenericSampleStateTensorsHost, GenericSampleStateTensorsDevice]):
     scheduled_requests: ScheduledRequests
 
-    device: Optional[SampleStateTensors] = None
-    host: Optional[SampleStateTensors] = None
+    device: Optional[GenericSampleStateTensorsDevice] = None
+    host: Optional[GenericSampleStateTensorsHost] = None
 
     sampler_event: Optional[SamplerEvent] = None
 
 
-class Sampler(ABC):
-    SampleState = SampleState
+GenericSampleState = TypeVar("GenericSampleState", bound=SampleState)
+
 
+class Sampler(ABC, Generic[GenericSampleState]):
     def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         pass
 
@@ -165,13 +173,13 @@ class Sampler(ABC):
         model_outputs,
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None,
-    ) -> SampleState:
+    ) -> GenericSampleState:
         raise NotImplementedError
 
     @abstractmethod
     def update_requests(
         self,
-        state: SampleState,
+        state: GenericSampleState,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         raise NotImplementedError
@@ -191,12 +199,14 @@ class Sampler(ABC):
         return True  # conservative default
 
 
-class EarlyStopSampler(Sampler):
+class EarlyStopSampler(Sampler[SampleState[SampleStateTensors, SampleStateTensors]]):
     """
     Use for skipping decoding step for non generation model,
     such as encoder-only model (e.g., BERT) or reward models that only need context phase.
     """
 
+    SampleState: TypeAlias = SampleState[SampleStateTensors, SampleStateTensors]
+
     @override
     def sample_async(
         self,
@@ -206,7 +216,7 @@ class EarlyStopSampler(Sampler):
         resource_manager: Optional[ResourceManager] = None,
     ) -> SampleState:
         host = SampleStateTensors(new_tokens=torch.empty(0))
-        return SampleState(scheduled_requests=scheduled_requests, host=host)
+        return self.SampleState(scheduled_requests=scheduled_requests, host=host)
 
     @override
     def update_requests(
@@ -238,9 +248,7 @@ class MultimodalResult:
 
 
 @dataclass(kw_only=True)
-class SampleStateWithMMResult:
-    scheduled_requests: ScheduledRequests
-
+class SampleStateWithMMResult(SampleState[SampleStateTensors, SampleStateTensors]):
     data: MultimodalResult
 
 
@@ -270,34 +278,36 @@ class RequestGroupValueWithMetadata(RequestGroupValue):
     metadata: StrategyMetadata | None
 
 
-class EarlyStopWithMMResult(Sampler):
+class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
    """
     Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings)
     """
 
+    SampleState: TypeAlias = SampleStateWithMMResult
+
     @override
-    def sample_async(  # type: ignore
+    def sample_async(
         self,
         scheduled_requests: ScheduledRequests,
         model_outputs,
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None,
-    ) -> SampleStateWithMMResult:
+    ) -> SampleState:
         # from model_outputs to MultimodalResult
         data = MultimodalResult(
             mm_embeddings=model_outputs.pop("mm_embeddings"),
             extra_data={**model_outputs},
         )
-        return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
+        return self.SampleState(scheduled_requests=scheduled_requests, data=data)
 
     @override
-    def update_requests(  # type: ignore
+    def update_requests(
         self,
-        state: SampleStateWithMMResult,
+        state: SampleState,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         # resource_manager will not be used in this function, just for interface consistency.
-        assert isinstance(state, SampleStateWithMMResult)
+        assert isinstance(state, SampleState)
         scheduled_requests = state.scheduled_requests
         assert not scheduled_requests.generation_requests
         mm_embeddings = state.data.mm_embeddings
@@ -310,9 +320,10 @@ class EarlyStopWithMMResult(Sampler):
             request.state = LlmRequestState.GENERATION_COMPLETE
             # NOTE: This is a hack: set finish reason manually and set the beam 0
             request.set_finished_reason(FinishReason.LENGTH, 0)
-            if len(mm_embedding) != sum(request.multimodal_lengths):  # type: ignore
+            assert request.multimodal_lengths is not None
+            if len(mm_embedding) != sum(request.multimodal_lengths):
                 raise ValueError(
-                    f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"  # type: ignore
+                    f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
                 )
 
             request.py_result.append_mm_embeddings(mm_embedding)
@@ -384,13 +395,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
 def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
     # We try to cache the resolved strategy on the request object, as it's not cheap enough to
     # resolve it on every iteration.
-    if hasattr(request, "py_sampling_strategy"):
-        return request.py_sampling_strategy  # type: ignore
+    if request._py_sampling_strategy is not None:
+        return request._py_sampling_strategy
 
     params = _request_get_sampling_params(request)
     sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
     if _request_sampling_params_cachable(params):
-        request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)  # type: ignore
+        request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
     return sampling_strategy
 
 
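Aside: a minimal, standalone sketch of the caching pattern the hunk above switches to (the names here are hypothetical, not TensorRT-LLM API). Declaring the cache slot as a typed Optional attribute in __init__ makes it visible to static type checkers, which is why the hasattr()/setattr() variant needed "# type: ignore" and this one does not.

from typing import Callable, Optional


class _Request:
    def __init__(self) -> None:
        # Cache slot declared up front, so its type (Optional[str]) is always known.
        self._cached_strategy: Optional[str] = None


def request_strategy(request: _Request, resolve: Callable[[], str]) -> str:
    # Return the cached value when present; resolving is assumed to be expensive.
    if request._cached_strategy is not None:
        return request._cached_strategy
    strategy = resolve()
    request._cached_strategy = strategy
    return strategy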
@@ -777,8 +788,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateTorch(SampleState):
-    host: SampleStateTensorsHostTorch  # type: ignore
+class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
     beam_histories: list[BeamHistory | None] | None = None
 
 
@@ -885,8 +895,7 @@ class AsyncWorkerMixin:
         return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
 
 
-class TorchSampler(Sampler, AsyncWorkerMixin):
-    SampleState = SampleStateTorch
+class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
     DEFAULT_MAX_TOPK_LOGPROBS = 20
 
     @override
@@ -1527,7 +1536,6 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             )
         )
 
-    @override
     @override
     def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
@@ -2019,11 +2027,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
     @torch.inference_mode()
     def update_requests(
         self,
-        state: Sampler.SampleState,
+        state: SampleStateTorch,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
-        state = cast(SampleStateTorch, state)
+        assert isinstance(state, SampleStateTorch)
         if state.sampler_event:
             state.sampler_event.synchronize()
 
@@ -3318,13 +3324,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateTRTLLM(SampleState):
+class SampleStateTRTLLM(SampleState[SampleStateTensorsHostTRTLLM, SampleStateTensors]):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM] = None  # type: ignore
 
 
-class TRTLLMSampler(Sampler, AsyncWorkerMixin):
+class TRTLLMSampler(Sampler[SampleStateTRTLLM], AsyncWorkerMixin):
     MAX_DECODING_TOKENS = 1  # It must be 1 when not in speculative decoding
-    SampleState = SampleStateTRTLLM
 
@@ -3452,20 +3457,20 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
     @torch.inference_mode()
     @nvtx_range("setup_sampler_step")
-    def setup_sampler_step(self, requests):  # type: ignore
+    def setup_sampler_step(self, scheduled_requests):
         batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = (
             self.algs.create_new_decoder_requests(  # type: ignore
                 self.model_config,
                 self.world_config,
                 self.decoding_config,
-                requests.context_requests,
+                scheduled_requests.context_requests,
                 self.logits_datatype,
                 self.store["decoder_input_buffers"][self.micro_batch_idx],
                 self.store["decoder_state"],
                 self.store["cuda_stream"],
                 self.algs.decoder.decoder_stream,  # type: ignore
                 self.max_seq_len,
-                self.beam_width(requests.context_requests),
+                self.beam_width(scheduled_requests.context_requests),
             )
         )
 
@@ -3482,11 +3487,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             lookahead_algo_configs,
         )
 
-        adp = [r for r in requests.generation_requests if r.is_attention_dp_dummy]
+        adp = [r for r in scheduled_requests.generation_requests if r.is_attention_dp_dummy]
         batch_size = len(adp)
         if batch_size == 0:
             return
-        config = make_sampling_config([r.sampling_config for r in adp])  # type: ignore
+        config = make_sampling_config(cast(SamplingConfigVector, [r.sampling_config for r in adp]))
         slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
         self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)  # type: ignore
 
@@ -3592,7 +3597,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
     @torch.inference_mode()
     @override
-    def update_requests(  # type: ignore
+    def update_requests(
         self,
         state: SampleStateTRTLLM,
         resource_manager: Optional[ResourceManager] = None,
@@ -3616,8 +3621,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
     @nvtx_range("update_requests_single_beam_single_step")
     def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         """Specialization of update_requests for single beam and single step"""
-        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()  # type: ignore
-        finish_reasons = state.host.finish_reasons.flatten().tolist()  # type: ignore
+        assert state.host is not None
+        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
+        finish_reasons = state.host.finish_reasons.flatten().tolist()
 
         reqs = [
             r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3636,7 +3642,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
         seq_slots = []
         seq_slots_need_log_probs = []
         for request in reqs:
-            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):  # type: ignore
+            assert request.py_seq_slot is not None
+            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
                 continue
 
             reqs_with_new_tokens.append(request)
@@ -3646,19 +3653,21 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                 seq_slots_need_log_probs.append(request.py_seq_slot)
 
         # [maxTokensPerStep, batchSize, maxBeamWidth]
-        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()  # type: ignore
+        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
 
         # Log probs
-        if state.host.log_probs is not None:  # type: ignore
+        assert state.host is not None
+        if state.host.log_probs is not None:
             # [batchSize, maxBeamWidth]
-            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1  # type: ignore
+            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
             # [batchSize, maxBeamWidth, maxSequenceLength]
-            log_probs_host = state.host.log_probs[  # type: ignore
+            log_probs_host = state.host.log_probs[
                 seq_slots_need_log_probs, 0, seq_last_idx
             ].tolist()
             # [batchSize, maxBeamWidth]
-            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()  # type: ignore
+            assert state.host.cum_log_probs is not None
+            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
 
             log_probs_idx = 0
             for request, new_token in zip(reqs_with_new_tokens, new_tokens):
@@ -3677,7 +3686,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
         for request in reqs:
             request.py_decoding_iter += 1
-            finished_state = FinishedState(finish_reasons[request.py_seq_slot])  # type: ignore
+            assert request.py_seq_slot is not None
+            finished_state = FinishedState(finish_reasons[request.py_seq_slot])
             if finished_state.is_finished:
                 request.state = LlmRequestState.GENERATION_COMPLETE
                 finish_reason = finished_state.to_finish_reason()
@@ -3690,14 +3700,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens.tolist()  # type: ignore
-        finished_sum_host = state.host.finished_sum.tolist()  # type: ignore
-        finish_reasons = state.host.finish_reasons.flatten().tolist()  # type: ignore
-        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()  # type: ignore
+        assert state.host is not None
+        new_tokens_host = state.host.new_tokens.tolist()
+        finished_sum_host = state.host.finished_sum.tolist()
+        finish_reasons = state.host.finish_reasons.flatten().tolist()
+        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         cum_log_probs_host = (
-            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None  # type: ignore
+            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )
-        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None  # type: ignore
+        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
         finalize_events = state.finalize_events
 
         reqs = [
@@ -3710,6 +3721,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
         for request in reqs:
             seq_slot = request.py_seq_slot
+            assert seq_slot is not None
            num_generated_tokens = request.num_draft_tokens + 1
             current_num_of_tokens = request.max_beam_num_tokens
             num_new_tokens = [0] * beam_width
@@ -3718,7 +3730,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             cum_log_probs = []
 
             for beam_idx in range(beam_width):
-                seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]  # type: ignore
+                seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]
                 num_new_tokens[beam_idx] = min(
                     num_generated_tokens, seq_len - request.get_num_tokens(beam_idx)
                 )
@@ -3727,7 +3739,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                     new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step)
 
                     if request.py_return_log_probs:
-                        assert state.host.log_probs is not None  # type: ignore
+                        assert state.host.log_probs is not None
+                        assert log_probs_host is not None
                         # NOTE: Log probs with drafting has not been tested yet.
                         begin_log_probs_offset = (
                             request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3738,7 +3751,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                            log_probs[beam_idx].append(
                                {
                                    new_token: Logprob(
-                                       logprob=log_probs_host[seq_slot][beam_idx][  # type: ignore
+                                       logprob=log_probs_host[seq_slot][beam_idx][
                                            begin_log_probs_offset + current_token
                                        ],
                                        rank=1,
@@ -3747,9 +3760,10 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                            )
 
                 if request.py_return_log_probs:
-                    cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])  # type: ignore
+                    assert cum_log_probs_host is not None
+                    cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])
 
-                finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])  # type: ignore
+                finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])
                 if finished_state.is_finished:
                     finish_reason = finished_state.to_finish_reason()
                     request.set_finished_reason(finish_reason, beam_idx)
@@ -3766,7 +3780,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 request.py_decoding_iter += 1
 
-            if finished_sum_host[seq_slot] == beam_width:  # type: ignore
+            if finished_sum_host[seq_slot] == beam_width:
                 request.state = LlmRequestState.GENERATION_COMPLETE
         for request in reqs:
             if finalize_events is not None and request.request_id in finalize_events:
@@ -3789,19 +3803,25 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             request: LlmRequest which shall be post processed
             finalize_event: CudaEvent to wait for the finalize step to finish
         """
+        assert state.host is not None
         seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
         # synchronize on the finalize event before continuing the post processing.
         # should be unnecessary, as already wait for the sampler event in update_requests
+        assert state.finalize_events is not None
         state.finalize_events[request.request_id].synchronize()  # type: ignore
 
         # Get these values again, as they might have changed during the finalize step
-        output_ids_host = state.host.gathered_ids  # type: ignore
-        sequence_lengths_host = state.host.sequence_lengths  # type: ignore
+        output_ids_host = state.host.gathered_ids
+        assert output_ids_host is not None
+        sequence_lengths_host = state.host.sequence_lengths
 
         if request.py_return_log_probs:
-            log_probs_host = state.host.log_probs  # type: ignore
-            cum_log_probs_host = state.host.cum_log_probs  # type: ignore
+            log_probs_host = state.host.log_probs
+            cum_log_probs_host = state.host.cum_log_probs
         else:
             log_probs_host = None
             cum_log_probs_host = None
 
         generated_tokens = [[0]] * beam_width
         log_probs = [[] for _ in range(beam_width)]
@@ -3814,11 +3834,13 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                 sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len
             )
             end = begin + generated_length
-            generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()  # type: ignore
+            generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()
 
             # get the correct log probs for beam search
             if request.py_return_log_probs:
-                cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())  # type: ignore
+                assert log_probs_host is not None
+                assert cum_log_probs_host is not None
+                cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())
 
                 begin_log_probs_offset = (
                     request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3827,7 +3849,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                    log_probs[beam_idx].append(
                        {
                            token: Logprob(
-                               logprob=log_probs_host[seq_slot, beam_idx][  # type: ignore
+                               logprob=log_probs_host[seq_slot, beam_idx][
                                    begin_log_probs_offset + current_token
                                ].item(),
                                rank=1,
@@ -1,3 +1,4 @@
+import sys
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -13,7 +14,7 @@ from ..distributed.ops import allgather
 from ..model_config import ModelConfig
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
-from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
+from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, Sampler, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -22,6 +23,11 @@ from .interface import SpecMetadata, SpecWorkerBase
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
 
+if sys.version_info[:2] >= (3, 12):
+    from typing import override
+else:
+    from typing_extensions import override
+
 
 @dataclass(kw_only=True)
 class SampleStateTensorsMTP(SampleStateTensors):
@@ -30,9 +36,8 @@ class SampleStateTensorsMTP(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateMTP(SampleState):
-    device: SampleStateTensorsMTP
-    host: SampleStateTensorsMTP
+class SampleStateMTP(SampleState[SampleStateTensorsMTP, SampleStateTensorsMTP]):
+    pass
 
 
 class MTPHiddenStatesManager(BaseResourceManager):
@@ -210,13 +215,17 @@ class MTPSpecMetadata(SpecMetadata):
         self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
 
 
-class MTPSampler(TorchSampler):
+class MTPSampler(Sampler[SampleStateMTP]):
     """
     MTP sampler.
     """
 
     SampleState = SampleStateMTP
 
+    @override
+    def is_generation_model(self) -> bool:
+        return True
+
     @dataclass(kw_only=True)
     class Store(TorchSampler.Store):
         new_tokens: torch.Tensor
@@ -84,6 +84,9 @@ class TestStrategySelection:
        sampling_config: SamplingConfig
        is_context_init_state: bool  # Torch sampler accesses this, but it does not affect this test
 
+        def __init__(self):
+            self._py_sampling_strategy: Strategy | None = None
+
        def get_beam_width_by_iter(
            self, for_next_iteration: bool = False
        ) -> int:  # Torch sampler accesses this, but it does not affect this test
@@ -1561,8 +1564,12 @@ class TestBatchedSampling:
            generator: Optional[torch.Generator] = None,
            return_probs: bool,
            group_metadata: StrategyMetadata | None = None,
-        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor] | float]:
            assert generator is sampler.get_generator(logits.device)
+            if isinstance(group_key, tuple):
+                assert isinstance(group_key[0], str)
+            else:
+                assert isinstance(group_key, str)
            nonlocal flashinfer_keys_seen
            assert (group_key, return_probs) not in flashinfer_keys_seen
            flashinfer_keys_seen.add((group_key, return_probs))