diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index f48d724658..96522f3f9b 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -1,10 +1,14 @@
 from copy import copy, deepcopy
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 
 import tensorrt_llm.bindings
+
+if TYPE_CHECKING:
+    from tensorrt_llm._torch.pyexecutor.sampler import Strategy
+
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
@@ -583,6 +587,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
             additional_outputs=additional_outputs)
         self.child_requests = []
 
+        self._py_sampling_strategy: "Strategy | None" = None
+
         self._py_embedding_bias_1d: Optional[torch.Tensor] = None
         if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
             # Pre-squeeze to 1D if needed (remove batch dimension)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index ac29abe979..ed9aae6ccb 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -35,6 +35,7 @@ from tensorrt_llm.bindings import (
     CudaStream,
     DataType,
     ModelConfig,
+    SamplingConfigVector,
     WorldConfig,
     make_sampling_config,
 )
@@ -99,8 +100,8 @@ class LogProbsState:
 
 @dataclass(kw_only=True)
 class LogProbsStateList:
-    FloatState = list[list[list[float]]]
-    IntState = list[list[list[int]]]
+    FloatState: TypeAlias = list[list[list[float]]]
+    IntState: TypeAlias = list[list[list[int]]]
 
     sampled_vals: FloatState
     sampled_indices: IntState
@@ -139,19 +140,26 @@ class SamplerEvent:
         self.cuda_event.synchronize()
 
 
+GenericSampleStateTensorsHost = TypeVar("GenericSampleStateTensorsHost", bound=SampleStateTensors)
+GenericSampleStateTensorsDevice = TypeVar(
+    "GenericSampleStateTensorsDevice", bound=SampleStateTensors
+)
+
+
 @dataclass(kw_only=True)
-class SampleState:
+class SampleState(Generic[GenericSampleStateTensorsHost, GenericSampleStateTensorsDevice]):
     scheduled_requests: ScheduledRequests
 
-    device: Optional[SampleStateTensors] = None
-    host: Optional[SampleStateTensors] = None
+    device: Optional[GenericSampleStateTensorsDevice] = None
+    host: Optional[GenericSampleStateTensorsHost] = None
 
     sampler_event: Optional[SamplerEvent] = None
 
 
-class Sampler(ABC):
-    SampleState = SampleState
+GenericSampleState = TypeVar("GenericSampleState", bound=SampleState)
+
 
+class Sampler(ABC, Generic[GenericSampleState]):
     def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         pass
 
@@ -165,13 +173,13 @@ class Sampler(ABC):
         model_outputs,
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None,
-    ) -> SampleState:
+    ) -> GenericSampleState:
         raise NotImplementedError
 
     @abstractmethod
     def update_requests(
         self,
-        state: SampleState,
+        state: GenericSampleState,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         raise NotImplementedError
@@ -191,12 +199,14 @@ class Sampler(ABC):
         return True  # conservative default
 
 
-class EarlyStopSampler(Sampler):
+class EarlyStopSampler(Sampler[SampleState[SampleStateTensors, SampleStateTensors]]):
     """
     Use for skipping decoding step for non generation model,
     such as encoder-only model (e.g., BERT) or reward models that only need context phase.
     """
 
+    SampleState: TypeAlias = SampleState[SampleStateTensors, SampleStateTensors]
+
     @override
     def sample_async(
         self,
@@ -206,7 +216,7 @@ class EarlyStopSampler(Sampler):
         resource_manager: Optional[ResourceManager] = None,
     ) -> SampleState:
         host = SampleStateTensors(new_tokens=torch.empty(0))
-        return SampleState(scheduled_requests=scheduled_requests, host=host)
+        return self.SampleState(scheduled_requests=scheduled_requests, host=host)
 
     @override
     def update_requests(
@@ -238,9 +248,7 @@ class MultimodalResult:
 
 
 @dataclass(kw_only=True)
-class SampleStateWithMMResult:
-    scheduled_requests: ScheduledRequests
-
+class SampleStateWithMMResult(SampleState[SampleStateTensors, SampleStateTensors]):
     data: MultimodalResult
 
 
@@ -270,34 +278,36 @@ class RequestGroupValueWithMetadata(RequestGroupValue):
     metadata: StrategyMetadata | None
 
 
-class EarlyStopWithMMResult(Sampler):
+class EarlyStopWithMMResult(Sampler[SampleStateWithMMResult]):
     """
     Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings)
     """
 
+    SampleState: TypeAlias = SampleStateWithMMResult
+
     @override
-    def sample_async(  # type: ignore
+    def sample_async(
         self,
         scheduled_requests: ScheduledRequests,
         model_outputs,
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None,
-    ) -> SampleStateWithMMResult:
+    ) -> SampleState:
         # from model_outputs to MultimodalResult
         data = MultimodalResult(
             mm_embeddings=model_outputs.pop("mm_embeddings"),
             extra_data={**model_outputs},
         )
-        return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
+        return self.SampleState(scheduled_requests=scheduled_requests, data=data)
 
     @override
-    def update_requests(  # type: ignore
+    def update_requests(
         self,
-        state: SampleStateWithMMResult,
+        state: SampleState,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         # resource_manager will not be used in this function, just for interface consistency.
-        assert isinstance(state, SampleStateWithMMResult)
+        assert isinstance(state, SampleState)
         scheduled_requests = state.scheduled_requests
         assert not scheduled_requests.generation_requests
         mm_embeddings = state.data.mm_embeddings
@@ -310,9 +320,10 @@ class EarlyStopWithMMResult(Sampler):
             request.state = LlmRequestState.GENERATION_COMPLETE
             # NOTE: This is a hack: set finish reason manually and set the beam 0
             request.set_finished_reason(FinishReason.LENGTH, 0)
-            if len(mm_embedding) != sum(request.multimodal_lengths):  # type: ignore
+            assert request.multimodal_lengths is not None
+            if len(mm_embedding) != sum(request.multimodal_lengths):
                 raise ValueError(
-                    f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"  # type: ignore
+                    f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
                 )
             request.py_result.append_mm_embeddings(mm_embedding)
 
@@ -384,13 +395,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
 
 
 def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
     # We try to cache the resolved strategy on the request object, as it's not cheap enough to
     # resolve it on every iteration.
-    if hasattr(request, "py_sampling_strategy"):
-        return request.py_sampling_strategy  # type: ignore
+    if request._py_sampling_strategy is not None:
+        return request._py_sampling_strategy
     params = _request_get_sampling_params(request)
     sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
     if _request_sampling_params_cachable(params):
-        request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)  # type: ignore
+        request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
 
     return sampling_strategy
@@ -777,8 +788,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateTorch(SampleState):
-    host: SampleStateTensorsHostTorch  # type: ignore
+class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
     beam_histories: list[BeamHistory | None] | None = None
 
 
@@ -885,8 +895,7 @@ class AsyncWorkerMixin:
         return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
 
 
-class TorchSampler(Sampler, AsyncWorkerMixin):
-    SampleState = SampleStateTorch
+class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
     DEFAULT_MAX_TOPK_LOGPROBS = 20
 
     @override
@@ -1527,7 +1536,6 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             )
         )
 
-    @override
     @override
     def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
@@ -2019,11 +2027,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
 
     @torch.inference_mode()
     def update_requests(
         self,
-        state: Sampler.SampleState,
+        state: SampleStateTorch,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
-        state = cast(SampleStateTorch, state)
-        assert isinstance(state, SampleStateTorch)
         if state.sampler_event:
             state.sampler_event.synchronize()
@@ -3318,13 +3324,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateTRTLLM(SampleState):
+class SampleStateTRTLLM(SampleState[SampleStateTensorsHostTRTLLM, SampleStateTensors]):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM] = None  # type: ignore
 
 
-class TRTLLMSampler(Sampler, AsyncWorkerMixin):
+class TRTLLMSampler(Sampler[SampleStateTRTLLM], AsyncWorkerMixin):
     MAX_DECODING_TOKENS = 1  # It must be 1 when not in speculative decoding
     SampleState = SampleStateTRTLLM
 
@@ -3452,20 +3457,20 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
     @torch.inference_mode()
     @nvtx_range("setup_sampler_step")
-    def setup_sampler_step(self, requests):  # type: ignore
+    def setup_sampler_step(self, scheduled_requests):
         batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = (
             self.algs.create_new_decoder_requests(  # type: ignore
                 self.model_config,
                 self.world_config,
                 self.decoding_config,
-                requests.context_requests,
+                scheduled_requests.context_requests,
                 self.logits_datatype,
                 self.store["decoder_input_buffers"][self.micro_batch_idx],
                 self.store["decoder_state"],
                 self.store["cuda_stream"],
                 self.algs.decoder.decoder_stream,  # type: ignore
                 self.max_seq_len,
-                self.beam_width(requests.context_requests),
+                self.beam_width(scheduled_requests.context_requests),
             )
         )
 
@@ -3482,11 +3487,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                 lookahead_algo_configs,
             )
 
-        adp = [r for r in requests.generation_requests if r.is_attention_dp_dummy]
+        adp = [r for r in scheduled_requests.generation_requests if r.is_attention_dp_dummy]
         batch_size = len(adp)
         if batch_size == 0:
             return
-        config = make_sampling_config([r.sampling_config for r in adp])  # type: ignore
+        config = make_sampling_config(cast(SamplingConfigVector, [r.sampling_config for r in adp]))
         slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
         self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)  # type: ignore
 
@@ -3592,7 +3597,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
     @torch.inference_mode()
     @override
-    def update_requests(  # type: ignore
+    def update_requests(
         self,
         state: SampleStateTRTLLM,
         resource_manager: Optional[ResourceManager] = None,
@@ -3616,8 +3621,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
     @nvtx_range("update_requests_single_beam_single_step")
     def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         """Specialization of update_requests for single beam and single step"""
-        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()  # type: ignore
-        finish_reasons = state.host.finish_reasons.flatten().tolist()  # type: ignore
+        assert state.host is not None
+        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
+        finish_reasons = state.host.finish_reasons.flatten().tolist()
         reqs = [
             r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3636,7 +3642,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
         seq_slots = []
         seq_slots_need_log_probs = []
         for request in reqs:
-            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):  # type: ignore
+            assert request.py_seq_slot is not None
+            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
                 continue
 
             reqs_with_new_tokens.append(request)
@@ -3646,19 +3653,21 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             seq_slots_need_log_probs.append(request.py_seq_slot)
 
         # [maxTokensPerStep, batchSize, maxBeamWidth]
-        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()  # type: ignore
+        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
 
         # Log probs
-        if state.host.log_probs is not None:  # type: ignore
+        assert state.host is not None
+        if state.host.log_probs is not None:
             # [batchSize, maxBeamWidth]
-            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1  # type: ignore
+            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
             # [batchSize, maxBeamWidth, maxSequenceLength]
-            log_probs_host = state.host.log_probs[  # type: ignore
+            log_probs_host = state.host.log_probs[
                 seq_slots_need_log_probs, 0, seq_last_idx
             ].tolist()
             # [batchSize, maxBeamWidth]
-            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()  # type: ignore
+            assert state.host.cum_log_probs is not None
+            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
 
             log_probs_idx = 0
             for request, new_token in zip(reqs_with_new_tokens, new_tokens):
@@ -3677,7 +3686,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
         for request in reqs:
            request.py_decoding_iter += 1
-            finished_state = FinishedState(finish_reasons[request.py_seq_slot])  # type: ignore
+            assert request.py_seq_slot is not None
+            finished_state = FinishedState(finish_reasons[request.py_seq_slot])
             if finished_state.is_finished:
                 request.state = LlmRequestState.GENERATION_COMPLETE
                 finish_reason = finished_state.to_finish_reason()
@@ -3690,14 +3700,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens.tolist()  # type: ignore
-        finished_sum_host = state.host.finished_sum.tolist()  # type: ignore
-        finish_reasons = state.host.finish_reasons.flatten().tolist()  # type: ignore
-        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()  # type: ignore
+        assert state.host is not None
+        new_tokens_host = state.host.new_tokens.tolist()
+        finished_sum_host = state.host.finished_sum.tolist()
+        finish_reasons = state.host.finish_reasons.flatten().tolist()
+        sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         cum_log_probs_host = (
-            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None  # type: ignore
+            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )
-        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None  # type: ignore
+        log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
         finalize_events = state.finalize_events
 
         reqs = [
@@ -3710,6 +3721,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
 
         for request in reqs:
             seq_slot = request.py_seq_slot
+            assert seq_slot is not None
             num_generated_tokens = request.num_draft_tokens + 1
             current_num_of_tokens = request.max_beam_num_tokens
             num_new_tokens = [0] * beam_width
@@ -3718,7 +3730,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             cum_log_probs = []
 
             for beam_idx in range(beam_width):
-                seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]  # type: ignore
+                seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]
                 num_new_tokens[beam_idx] = min(
                     num_generated_tokens, seq_len - request.get_num_tokens(beam_idx)
                 )
@@ -3727,7 +3739,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                    new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step)
 
                     if request.py_return_log_probs:
-                        assert state.host.log_probs is not None  # type: ignore
+                        assert state.host.log_probs is not None
+                        assert log_probs_host is not None
                         # NOTE: Log probs with drafting has not been tested yet.
                        begin_log_probs_offset = (
                            request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3738,7 +3751,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                         log_probs[beam_idx].append(
                             {
                                 new_token: Logprob(
-                                    logprob=log_probs_host[seq_slot][beam_idx][  # type: ignore
+                                    logprob=log_probs_host[seq_slot][beam_idx][
                                         begin_log_probs_offset + current_token
                                     ],
                                     rank=1,
@@ -3747,9 +3760,10 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                         )
 
                 if request.py_return_log_probs:
-                    cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])  # type: ignore
+                    assert cum_log_probs_host is not None
+                    cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])
 
-                finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])  # type: ignore
+                finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])
                 if finished_state.is_finished:
                     finish_reason = finished_state.to_finish_reason()
                     request.set_finished_reason(finish_reason, beam_idx)
@@ -3766,7 +3780,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 request.py_decoding_iter += 1
-            if finished_sum_host[seq_slot] == beam_width:  # type: ignore
+            if finished_sum_host[seq_slot] == beam_width:
                 request.state = LlmRequestState.GENERATION_COMPLETE
 
         for request in reqs:
             if finalize_events is not None and request.request_id in finalize_events:
@@ -3789,19 +3803,25 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
             request: LlmRequest which shall be post processed
             finalize_event: CudaEvent to wait for the finalize step to finish
         """
+        assert state.host is not None
         seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
 
         # synchronize on the finalize event before continuing the post processing.
         # should be unnecessary, as already wait for the sampler event in update_requests
+        assert state.finalize_events is not None
         state.finalize_events[request.request_id].synchronize()  # type: ignore
 
         # Get these values again, as they might have changed during the finalize step
-        output_ids_host = state.host.gathered_ids  # type: ignore
-        sequence_lengths_host = state.host.sequence_lengths  # type: ignore
+        output_ids_host = state.host.gathered_ids
+        assert output_ids_host is not None
+        sequence_lengths_host = state.host.sequence_lengths
         if request.py_return_log_probs:
-            log_probs_host = state.host.log_probs  # type: ignore
-            cum_log_probs_host = state.host.cum_log_probs  # type: ignore
+            log_probs_host = state.host.log_probs
+            cum_log_probs_host = state.host.cum_log_probs
+        else:
+            log_probs_host = None
+            cum_log_probs_host = None
 
         generated_tokens = [[0]] * beam_width
         log_probs = [[] for _ in range(beam_width)]
@@ -3814,11 +3834,13 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                 sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len
             )
             end = begin + generated_length
-            generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()  # type: ignore
+            generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()
 
             # get the correct log probs for beam search
             if request.py_return_log_probs:
-                cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())  # type: ignore
+                assert log_probs_host is not None
+                assert cum_log_probs_host is not None
+                cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())
 
                 begin_log_probs_offset = (
                     request.prompt_len if request.sampling_config.beam_width == 1 else 0
@@ -3827,7 +3849,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
                 log_probs[beam_idx].append(
                     {
                         token: Logprob(
-                            logprob=log_probs_host[seq_slot, beam_idx][  # type: ignore
+                            logprob=log_probs_host[seq_slot, beam_idx][
                                 begin_log_probs_offset + current_token
                             ].item(),
                             rank=1,
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 85ad553a50..a49c7f9b5e 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -1,3 +1,4 @@
+import sys
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -13,7 +14,7 @@ from ..distributed.ops import allgather
 from ..model_config import ModelConfig
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
-from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, SampleState,
+from ..pyexecutor.sampler import (DEFAULT_BEAM_IDX, Sampler, SampleState,
                                   SampleStateTensors, TorchSampler, add_token,
                                   int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -22,6 +23,11 @@ from .interface import SpecMetadata, SpecWorkerBase
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
 
+if sys.version_info[:2] >= (3, 12):
+    from typing import override
+else:
+    from typing_extensions import override
+
 
 @dataclass(kw_only=True)
 class SampleStateTensorsMTP(SampleStateTensors):
@@ -30,9 +36,8 @@ class SampleStateTensorsMTP(SampleStateTensors):
 
 
 @dataclass(kw_only=True)
-class SampleStateMTP(SampleState):
-    device: SampleStateTensorsMTP
-    host: SampleStateTensorsMTP
+class SampleStateMTP(SampleState[SampleStateTensorsMTP, SampleStateTensorsMTP]):
+    pass
 
 
 class MTPHiddenStatesManager(BaseResourceManager):
@@ -210,13 +215,17 @@
             self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
 
 
-class MTPSampler(TorchSampler):
+class MTPSampler(Sampler[SampleStateMTP]):
     """
     MTP sampler.
     """
 
     SampleState = SampleStateMTP
 
+    @override
+    def is_generation_model(self) -> bool:
+        return True
+
     @dataclass(kw_only=True)
     class Store(TorchSampler.Store):
         new_tokens: torch.Tensor
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index 61c79e77a6..0dc5dc0ad5 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -84,6 +84,9 @@ class TestStrategySelection:
         sampling_config: SamplingConfig
         is_context_init_state: bool  # Torch sampler accesses this, but it does not affect this test
 
+        def __init__(self):
+            self._py_sampling_strategy: Strategy | None = None
+
         def get_beam_width_by_iter(
             self, for_next_iteration: bool = False
         ) -> int:  # Torch sampler accesses this, but it does not affect this test
@@ -1561,8 +1564,12 @@ class TestBatchedSampling:
             generator: Optional[torch.Generator] = None,
             return_probs: bool,
             group_metadata: StrategyMetadata | None = None,
-        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor] | float]:
             assert generator is sampler.get_generator(logits.device)
+            if isinstance(group_key, tuple):
+                assert isinstance(group_key[0], str)
+            else:
+                assert isinstance(group_key, str)
             nonlocal flashinfer_keys_seen
             assert (group_key, return_probs) not in flashinfer_keys_seen
             flashinfer_keys_seen.add((group_key, return_probs))
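
For orientation, the sketch below shows how a downstream sampler would bind the generics introduced in `sampler.py` by this patch. It is a minimal, hypothetical example: `MyHostTensors`, `MySampleState`, and `MySampler` do not exist in the codebase, and it assumes that `sample_async` and `update_requests` are the overrides a concrete `Sampler` subclass must provide.

```python
from dataclasses import dataclass

import torch

from tensorrt_llm._torch.pyexecutor.sampler import (
    Sampler,
    SampleState,
    SampleStateTensors,
)


# Hypothetical host-side tensor bundle; the extra field is illustrative only.
@dataclass(kw_only=True)
class MyHostTensors(SampleStateTensors):
    finish_reasons: torch.Tensor


# Bind the generic SampleState: the first parameter is the host tensors type,
# the second is the device tensors type (mirroring SampleStateTorch above).
@dataclass(kw_only=True)
class MySampleState(SampleState[MyHostTensors, SampleStateTensors]):
    pass


class MySampler(Sampler[MySampleState]):
    """Sketch of a sampler parameterized by its own SampleState subtype."""

    def sample_async(
        self,
        scheduled_requests,
        model_outputs,
        num_context_logits_prefix_sum,
        resource_manager=None,
    ) -> MySampleState:
        # Produce host tensors of the concrete type declared above.
        host = MyHostTensors(new_tokens=torch.empty(0), finish_reasons=torch.empty(0))
        return MySampleState(scheduled_requests=scheduled_requests, host=host)

    def update_requests(self, state: MySampleState, resource_manager=None) -> None:
        # state.host is statically typed as Optional[MyHostTensors]; a narrowing
        # assert replaces the cast/type-ignore pattern removed by this diff.
        assert state.host is not None
        _ = state.host.finish_reasons
```

With the `SampleState` host/device tensor types carried in the class signature, `state.host` and `state.device` resolve to the concrete tensor bundles statically, which is why the `cast(...)` calls and most `# type: ignore` comments could be dropped throughout this diff.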