[TRTLLM-10030][chore] promote SampleState to TypeVar + typing fixes (#11281)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-05 16:33:22 +01:00 committed by GitHub
parent eae480b713
commit 7d235cfb23
4 changed files with 122 additions and 78 deletions

View File

@@ -1,10 +1,14 @@
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import tensorrt_llm.bindings
if TYPE_CHECKING:
from tensorrt_llm._torch.pyexecutor.sampler import Strategy
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.bindings import executor as tllm_executor
from tensorrt_llm.executor.result import TokenLogprobs
@@ -583,6 +587,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
additional_outputs=additional_outputs)
self.child_requests = []
self._py_sampling_strategy: "Strategy | None" = None
self._py_embedding_bias_1d: Optional[torch.Tensor] = None
if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
# Pre-squeeze to 1D if needed (remove batch dimension)

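The LlmRequest change above is typing-only: Strategy is imported under typing.TYPE_CHECKING and referenced through a string annotation, so type checkers see the name while no runtime import (and no import cycle) is added. A minimal sketch of the pattern, with hypothetical module and class names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime, so it cannot
    # create an import cycle or add module-load cost.
    from mypkg.sampler import Strategy  # hypothetical path


class Request:
    def __init__(self) -> None:
        # Forward-reference (string) annotation, resolved lazily by the checker.
        self._py_sampling_strategy: "Strategy | None" = None
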
View File

@@ -35,6 +35,7 @@ from tensorrt_llm.bindings import (
CudaStream,
DataType,
ModelConfig,
SamplingConfigVector,
WorldConfig,
make_sampling_config,
)
@@ -99,8 +100,8 @@ class LogProbsState:
@dataclass(kw_only=True)
class LogProbsStateList:
FloatState = list[list[list[float]]]
IntState = list[list[list[int]]]
FloatState: TypeAlias = list[list[list[float]]]
IntState: TypeAlias = list[list[list[int]]]
sampled_vals: FloatState
sampled_indices: IntState
@@ -139,19 +140,26 @@ class SamplerEvent:
self.cuda_event.synchronize()
GenericSampleStateTensorsHost = TypeVar("GenericSampleStateTensorsHost", bound=SampleStateTensors)
GenericSampleStateTensorsDevice = TypeVar(
"GenericSampleStateTensorsDevice", bound=SampleStateTensors
)
@dataclass(kw_only=True)
class SampleState:
class SampleState(Generic[GenericSampleStateTensorsHost, GenericSampleStateTensorsDevice]):
scheduled_requests: ScheduledRequests
device: Optional[SampleStateTensors] = None
host: Optional[SampleStateTensors] = None
device: Optional[GenericSampleStateTensorsDevice] = None
host: Optional[GenericSampleStateTensorsHost] = None
sampler_event: Optional[SamplerEvent] = None
class Sampler(ABC):
SampleState = SampleState
GenericSampleState = TypeVar("GenericSampleState", bound=SampleState)
class Sampler(ABC, Generic[GenericSampleState]):
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
pass
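
This hunk is the core of the commit: SampleState becomes generic over its host- and device-tensor containers, and Sampler becomes generic over the concrete SampleState it produces, which is what lets the subclasses further down drop their # type: ignore comments and runtime casts. A self-contained sketch of the same TypeVar pattern (all names here are hypothetical stand-ins, not the TensorRT-LLM classes):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar


@dataclass(kw_only=True)
class Tensors:
    new_tokens: tuple = ()


HostT = TypeVar("HostT", bound=Tensors)
DeviceT = TypeVar("DeviceT", bound=Tensors)


@dataclass(kw_only=True)
class State(Generic[HostT, DeviceT]):
    host: Optional[HostT] = None
    device: Optional[DeviceT] = None


StateT = TypeVar("StateT", bound=State)


class BaseSampler(ABC, Generic[StateT]):
    @abstractmethod
    def sample_async(self) -> StateT: ...

    @abstractmethod
    def update_requests(self, state: StateT) -> None: ...


class SimpleSampler(BaseSampler[State[Tensors, Tensors]]):
    def sample_async(self) -> State[Tensors, Tensors]:
        return State(host=Tensors())

    def update_requests(self, state: State[Tensors, Tensors]) -> None:
        assert state.host is not None  # narrows Optional for the checker
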
@@ -165,13 +173,13 @@ class Sampler(ABC):
model_outputs,
num_context_logits_prefix_sum: list[int],
resource_manager: Optional[ResourceManager] = None,
) -> SampleState:
) -> GenericSampleState:
raise NotImplementedError
@abstractmethod
def update_requests(
self,
state: SampleState,
state: GenericSampleState,
resource_manager: Optional[ResourceManager] = None,
) -> None:
raise NotImplementedError
@@ -191,12 +199,14 @@ class Sampler(ABC):
return True # conservative default
class EarlyStopSampler(Sampler):
class EarlyStopSampler(Sampler[SampleState[SampleStateTensors, SampleStateTensors]]):
"""
Use for skipping decoding step for non generation model,
such as encoder-only model (e.g., BERT) or reward models that only need context phase.
"""
SampleState: TypeAlias = SampleState[SampleStateTensors, SampleStateTensors]
@override
def sample_async(
self,
@@ -206,7 +216,7 @@ class EarlyStopSampler(Sampler):
resource_manager: Optional[ResourceManager] = None,
) -> SampleState:
host = SampleStateTensors(new_tokens=torch.empty(0))
return SampleState(scheduled_requests=scheduled_requests, host=host)
return self.SampleState(scheduled_requests=scheduled_requests, host=host)
@override
def update_requests(
@@ -238,9 +248,7 @@ class MultimodalResult:
@dataclass(kw_only=True)
class SampleStateWithMMResult:
scheduled_requests: ScheduledRequests
class SampleStateWithMMResult(SampleState[SampleStateTensors, SampleStateTensors]):
data: MultimodalResult
@@ -270,34 +278,36 @@ class RequestGroupValueWithMetadata(RequestGroupValue):
metadata: StrategyMetadata | None
class EarlyStopWithMMResult(Sampler):
class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
"""
Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings)
"""
SampleState: TypeAlias = SampleStateWithMMResult
@override
def sample_async( # type: ignore
def sample_async(
self,
scheduled_requests: ScheduledRequests,
model_outputs,
num_context_logits_prefix_sum: list[int],
resource_manager: Optional[ResourceManager] = None,
) -> SampleStateWithMMResult:
) -> SampleState:
# from model_outputs to MultimodalResult
data = MultimodalResult(
mm_embeddings=model_outputs.pop("mm_embeddings"),
extra_data={**model_outputs},
)
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
return self.SampleState(scheduled_requests=scheduled_requests, data=data)
@override
def update_requests( # type: ignore
def update_requests(
self,
state: SampleStateWithMMResult,
state: SampleState,
resource_manager: Optional[ResourceManager] = None,
) -> None:
# resource_manager will not be used in this function, just for interface consistency.
assert isinstance(state, SampleStateWithMMResult)
assert isinstance(state, SampleState)
scheduled_requests = state.scheduled_requests
assert not scheduled_requests.generation_requests
mm_embeddings = state.data.mm_embeddings
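
The early-stop samplers now pin their concrete state type once as a class-level alias and construct it via self.SampleState(...), so the runtime isinstance check and the constructor cannot drift away from the declared type. A short, self-contained sketch of that idiom (hypothetical classes, Python 3.10+ for typing.TypeAlias):

from dataclasses import dataclass
from typing import TypeAlias


@dataclass(kw_only=True)
class MMState:
    data: dict


class EarlyStopLike:
    # Class-level alias: a single place names the concrete state type.
    SampleState: TypeAlias = MMState

    def sample_async(self, model_outputs: dict) -> MMState:
        # Constructing through the alias keeps it in sync with the annotation.
        return self.SampleState(data=model_outputs)

    def update_requests(self, state: MMState) -> None:
        assert isinstance(state, self.SampleState)


sampler = EarlyStopLike()
sampler.update_requests(sampler.sample_async({"mm_embeddings": []}))
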
@@ -310,9 +320,10 @@ class EarlyStopWithMMResult(Sampler):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set finish reason manually and set the beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
if len(mm_embedding) != sum(request.multimodal_lengths): # type: ignore
assert request.multimodal_lengths is not None
if len(mm_embedding) != sum(request.multimodal_lengths):
raise ValueError(
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}" # type: ignore
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
)
request.py_result.append_mm_embeddings(mm_embedding)
@@ -384,13 +395,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
# We try to cache the resolved strategy on the request object, as it's not cheap enough to
# resolve it on every iteration.
if hasattr(request, "py_sampling_strategy"):
return request.py_sampling_strategy # type: ignore
if request._py_sampling_strategy is not None:
return request._py_sampling_strategy
params = _request_get_sampling_params(request)
sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
if _request_sampling_params_cachable(params):
request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) # type: ignore
request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
return sampling_strategy
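
The resolved strategy is now cached in the _py_sampling_strategy attribute declared in LlmRequest.__init__ (first file of this commit) rather than attached on the fly via hasattr/setattr, so type checkers know the attribute and its type. A minimal sketch of the None-sentinel cache with hypothetical names; the real code additionally caches only when the sampling params are cachable:

from typing import Optional


def resolve_strategy(top_k: int) -> str:
    # Stand-in for resolve_sampling_strategy(); assume it is too costly
    # to recompute on every scheduler iteration.
    return "greedy" if top_k <= 1 else "top_k"


class Request:
    def __init__(self, top_k: int) -> None:
        self.top_k = top_k
        # Declared up front with an explicit type; None means "not resolved yet".
        self._py_sampling_strategy: Optional[str] = None


def request_strategy(request: Request) -> str:
    if request._py_sampling_strategy is not None:
        return request._py_sampling_strategy  # cache hit
    strategy = resolve_strategy(request.top_k)
    request._py_sampling_strategy = strategy  # cache for later iterations
    return strategy
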
@@ -777,8 +788,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
@dataclass(kw_only=True)
class SampleStateTorch(SampleState):
host: SampleStateTensorsHostTorch # type: ignore
class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
beam_histories: list[BeamHistory | None] | None = None
@@ -885,8 +895,7 @@ class AsyncWorkerMixin:
return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
class TorchSampler(Sampler, AsyncWorkerMixin):
SampleState = SampleStateTorch
class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
DEFAULT_MAX_TOPK_LOGPROBS = 20
@override
@@ -1527,7 +1536,6 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
)
@override
@override
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
"""Setup the sampler step for the requests
@@ -2019,11 +2027,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
@torch.inference_mode()
def update_requests(
self,
state: Sampler.SampleState,
state: SampleStateTorch,
resource_manager: Optional[ResourceManager] = None,
) -> None:
state = cast(SampleStateTorch, state)
assert isinstance(state, SampleStateTorch)
if state.sampler_event:
state.sampler_event.synchronize()
@@ -3318,13 +3324,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
@dataclass(kw_only=True)
class SampleStateTRTLLM(SampleState):
class SampleStateTRTLLM(SampleState[SampleStateTensorsHostTRTLLM, SampleStateTensors]):
finalize_events: dict[str, CudaEvent] | None = None
"""`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
host: Optional[SampleStateTensorsHostTRTLLM] = None # type: ignore
class TRTLLMSampler(Sampler, AsyncWorkerMixin):
class TRTLLMSampler(Sampler[SampleStateTRTLLM], AsyncWorkerMixin):
MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding
SampleState = SampleStateTRTLLM
@@ -3452,20 +3457,20 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@torch.inference_mode()
@nvtx_range("setup_sampler_step")
def setup_sampler_step(self, requests): # type: ignore
def setup_sampler_step(self, scheduled_requests):
batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = (
self.algs.create_new_decoder_requests( # type: ignore
self.model_config,
self.world_config,
self.decoding_config,
requests.context_requests,
scheduled_requests.context_requests,
self.logits_datatype,
self.store["decoder_input_buffers"][self.micro_batch_idx],
self.store["decoder_state"],
self.store["cuda_stream"],
self.algs.decoder.decoder_stream, # type: ignore
self.max_seq_len,
self.beam_width(requests.context_requests),
self.beam_width(scheduled_requests.context_requests),
)
)
@@ -3482,11 +3487,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
lookahead_algo_configs,
)
adp = [r for r in requests.generation_requests if r.is_attention_dp_dummy]
adp = [r for r in scheduled_requests.generation_requests if r.is_attention_dp_dummy]
batch_size = len(adp)
if batch_size == 0:
return
config = make_sampling_config([r.sampling_config for r in adp]) # type: ignore
config = make_sampling_config(cast(SamplingConfigVector, [r.sampling_config for r in adp]))
slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) # type: ignore
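
make_sampling_config expects a SamplingConfigVector from the bindings, while the comprehension yields a plain Python list that the binding already accepted at runtime; typing.cast records that fact for the checker at zero runtime cost, instead of the old blanket # type: ignore. A tiny sketch of the idiom with hypothetical types:

from typing import NewType, cast

# Hypothetical stand-in for a bindings-exported vector type.
SamplingConfigVector = NewType("SamplingConfigVector", list)


def make_config(configs: SamplingConfigVector) -> int:
    # Stand-in for the bound factory; at runtime it happily takes a list.
    return len(configs)


plain = [3, 1, 2]
# cast() is erased at runtime; it only tells the type checker to treat
# the list as the declared vector type rather than silencing the call.
config = make_config(cast(SamplingConfigVector, plain))
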
@@ -3592,7 +3597,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@torch.inference_mode()
@override
def update_requests( # type: ignore
def update_requests(
self,
state: SampleStateTRTLLM,
resource_manager: Optional[ResourceManager] = None,
@@ -3616,8 +3621,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@nvtx_range("update_requests_single_beam_single_step")
def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
"""Specialization of update_requests for single beam and single step"""
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
assert state.host is not None
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
reqs = [
r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3636,7 +3642,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
seq_slots = []
seq_slots_need_log_probs = []
for request in reqs:
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0): # type: ignore
assert request.py_seq_slot is not None
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
continue
reqs_with_new_tokens.append(request)
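
The recurring substitution in the rest of this file is assert x is not None in place of # type: ignore: the assert narrows the Optional type for the checker and turns a silent wrong-type access into an immediate failure. A short sketch with hypothetical names:

from typing import Optional


def first_new_token(slots: Optional[list[int]], idx: Optional[int]) -> int:
    # Before: slots[idx]  # type: ignore  -- silences the checker and hides bugs.
    # After: explicit asserts narrow Optional[...] to the concrete types.
    assert slots is not None
    assert idx is not None
    return slots[idx]


print(first_new_token([7, 8, 9], 1))  # -> 8
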
@@ -3646,19 +3653,21 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
seq_slots_need_log_probs.append(request.py_seq_slot)
# [maxTokensPerStep, batchSize, maxBeamWidth]
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist() # type: ignore
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
# Log probs
if state.host.log_probs is not None: # type: ignore
assert state.host is not None
if state.host.log_probs is not None:
# [batchSize, maxBeamWidth]
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1 # type: ignore
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
# [batchSize, maxBeamWidth, maxSequenceLength]
log_probs_host = state.host.log_probs[ # type: ignore
log_probs_host = state.host.log_probs[
seq_slots_need_log_probs, 0, seq_last_idx
].tolist()
# [batchSize, maxBeamWidth]
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist() # type: ignore
assert state.host.cum_log_probs is not None
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
log_probs_idx = 0
for request, new_token in zip(reqs_with_new_tokens, new_tokens):
@@ -3677,7 +3686,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
for request in reqs:
request.py_decoding_iter += 1
finished_state = FinishedState(finish_reasons[request.py_seq_slot]) # type: ignore
assert request.py_seq_slot is not None
finished_state = FinishedState(finish_reasons[request.py_seq_slot])
if finished_state.is_finished:
request.state = LlmRequestState.GENERATION_COMPLETE
finish_reason = finished_state.to_finish_reason()
@@ -3690,14 +3700,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
state: SampleStateTRTLLM,
beam_width: int,
):
new_tokens_host = state.host.new_tokens.tolist() # type: ignore
finished_sum_host = state.host.finished_sum.tolist() # type: ignore
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
assert state.host is not None
new_tokens_host = state.host.new_tokens.tolist()
finished_sum_host = state.host.finished_sum.tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
cum_log_probs_host = (
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None # type: ignore
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
)
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None # type: ignore
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
finalize_events = state.finalize_events
reqs = [
@@ -3710,6 +3721,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
for request in reqs:
seq_slot = request.py_seq_slot
assert seq_slot is not None
num_generated_tokens = request.num_draft_tokens + 1
current_num_of_tokens = request.max_beam_num_tokens
num_new_tokens = [0] * beam_width
@@ -3718,7 +3730,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
cum_log_probs = []
for beam_idx in range(beam_width):
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx] # type: ignore
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]
num_new_tokens[beam_idx] = min(
num_generated_tokens, seq_len - request.get_num_tokens(beam_idx)
)
@@ -3727,7 +3739,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step)
if request.py_return_log_probs:
assert state.host.log_probs is not None # type: ignore
assert state.host.log_probs is not None
assert log_probs_host is not None
# NOTE: Log probs with drafting has not been tested yet.
begin_log_probs_offset = (
request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3738,7 +3751,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
log_probs[beam_idx].append(
{
new_token: Logprob(
logprob=log_probs_host[seq_slot][beam_idx][ # type: ignore
logprob=log_probs_host[seq_slot][beam_idx][
begin_log_probs_offset + current_token
],
rank=1,
@@ -3747,9 +3760,10 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
)
if request.py_return_log_probs:
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx]) # type: ignore
assert cum_log_probs_host is not None
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx]) # type: ignore
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])
if finished_state.is_finished:
finish_reason = finished_state.to_finish_reason()
request.set_finished_reason(finish_reason, beam_idx)
@@ -3766,7 +3780,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
if request.state != LlmRequestState.GENERATION_COMPLETE:
request.py_decoding_iter += 1
if finished_sum_host[seq_slot] == beam_width: # type: ignore
if finished_sum_host[seq_slot] == beam_width:
request.state = LlmRequestState.GENERATION_COMPLETE
for request in reqs:
if finalize_events is not None and request.request_id in finalize_events:
@@ -3789,19 +3803,25 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
request: LlmRequest which shall be post processed
finalize_event: CudaEvent to wait for the finalize step to finish
"""
assert state.host is not None
seq_slot = request.py_seq_slot
beam_width = request.sampling_config.beam_width
# synchronize on the finalize event before continuing the post processing.
# should be unnecessary, as already wait for the sampler event in update_requests
assert state.finalize_events is not None
state.finalize_events[request.request_id].synchronize() # type: ignore
# Get these values again, as they might have changed during the finalize step
output_ids_host = state.host.gathered_ids # type: ignore
sequence_lengths_host = state.host.sequence_lengths # type: ignore
output_ids_host = state.host.gathered_ids
assert output_ids_host is not None
sequence_lengths_host = state.host.sequence_lengths
if request.py_return_log_probs:
log_probs_host = state.host.log_probs # type: ignore
cum_log_probs_host = state.host.cum_log_probs # type: ignore
log_probs_host = state.host.log_probs
cum_log_probs_host = state.host.cum_log_probs
else:
log_probs_host = None
cum_log_probs_host = None
generated_tokens = [[0]] * beam_width
log_probs = [[] for _ in range(beam_width)]
@@ -3814,11 +3834,13 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len
)
end = begin + generated_length
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist() # type: ignore
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()
# get the correct log probs for beam search
if request.py_return_log_probs:
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item()) # type: ignore
assert log_probs_host is not None
assert cum_log_probs_host is not None
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())
begin_log_probs_offset = (
request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3827,7 +3849,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
log_probs[beam_idx].append(
{
token: Logprob(
logprob=log_probs_host[seq_slot, beam_idx][ # type: ignore
logprob=log_probs_host[seq_slot, beam_idx][
begin_log_probs_offset + current_token
].item(),
rank=1,

View File

@@ -1,3 +1,4 @@
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
@@ -13,7 +14,7 @@ from ..distributed.ops import allgather
from ..model_config import ModelConfig
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, Sampler, SampleState,
SampleStateTensors, TorchSampler, add_token,
int_tensor)
from ..pyexecutor.scheduler import ScheduledRequests
@@ -22,6 +23,11 @@ from .interface import SpecMetadata, SpecWorkerBase
if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
if sys.version_info[:2] >= (3, 12):
from typing import override
else:
from typing_extensions import override
@dataclass(kw_only=True)
class SampleStateTensorsMTP(SampleStateTensors):
@@ -30,9 +36,8 @@ class SampleStateTensorsMTP(SampleStateTensors):
@dataclass(kw_only=True)
class SampleStateMTP(SampleState):
device: SampleStateTensorsMTP
host: SampleStateTensorsMTP
class SampleStateMTP(SampleState[SampleStateTensorsMTP, SampleStateTensorsMTP]):
pass
class MTPHiddenStatesManager(BaseResourceManager):
@@ -210,13 +215,17 @@ class MTPSpecMetadata(SpecMetadata):
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
class MTPSampler(TorchSampler):
class MTPSampler(Sampler[SampleStateMTP]):
"""
MTP sampler.
"""
SampleState = SampleStateMTP
@override
def is_generation_model(self) -> bool:
return True
@dataclass(kw_only=True)
class Store(TorchSampler.Store):
new_tokens: torch.Tensor

View File

@@ -84,6 +84,9 @@ class TestStrategySelection:
sampling_config: SamplingConfig
is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test
def __init__(self):
self._py_sampling_strategy: Strategy | None = None
def get_beam_width_by_iter(
self, for_next_iteration: bool = False
) -> int: # Torch sampler accesses this, but it does not affect this test
@@ -1561,8 +1564,12 @@ class TestBatchedSampling:
generator: Optional[torch.Generator] = None,
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor] | float]:
assert generator is sampler.get_generator(logits.device)
if isinstance(group_key, tuple):
assert isinstance(group_key[0], str)
else:
assert isinstance(group_key, str)
nonlocal flashinfer_keys_seen
assert (group_key, return_probs) not in flashinfer_keys_seen
flashinfer_keys_seen.add((group_key, return_probs))