diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index b7ad63821a..c9d6e1f44b 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -395,6 +395,25 @@ class SamplingParams:
             strs = [self.stop] if isinstance(self.stop, str) else self.stop
             self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]
 
+        # add generation_config to stop word list, only in qwen3-next now
+        if (
+            hf_model_config is not None
+            and hf_model_config.model_type == "qwen3_next"
+            and generation_config is not None
+            and isinstance(generation_config.eos_token_id, List)
+            and all(isinstance(i, int) for i in generation_config.eos_token_id)
+        ):
+            if self._stop_word_ids:
+                all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
+                from_generation_stop_tokens = [
+                    i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
+                ]
+
+                if from_generation_stop_tokens:
+                    self._stop_word_ids.append(from_generation_stop_tokens)
+            else:
+                self._stop_word_ids = [generation_config.eos_token_id]
+
         return self
 
     def _get_bad_words(self) -> List[List[int]]:
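
For context, here is a minimal standalone sketch of the merge behavior this hunk introduces (the helper name `merge_generation_config_stops` and the example token ids are illustrative, not part of the change): when `generation_config.eos_token_id` is a flat list of ints, any ids not already covered by the tokenized user stop words are appended as one extra stop-word entry; when there are no user stop words, that list becomes the sole entry.

```python
from typing import List, Optional


def merge_generation_config_stops(
    stop_word_ids: Optional[List[List[int]]],
    eos_token_id: object,
) -> List[List[int]]:
    """Sketch of the merging logic in the diff above (not the library API).

    `stop_word_ids` stands in for the tokenized user stop strings;
    `eos_token_id` is whatever generation_config carries (int, list, or None).
    """
    # Only a flat list of ints is folded in, mirroring the isinstance checks
    # in the diff; a scalar or missing eos_token_id leaves the stop words alone.
    if not (isinstance(eos_token_id, list)
            and all(isinstance(i, int) for i in eos_token_id)):
        return stop_word_ids or []

    if stop_word_ids:
        # Deduplicate against every token id already present in any stop entry.
        already_present = {i for sublist in stop_word_ids for i in sublist}
        extra = [i for i in eos_token_id if i not in already_present]
        if extra:
            stop_word_ids.append(extra)
        return stop_word_ids

    # No user stop words: the generation_config ids become the only entry.
    return [eos_token_id]


# Illustrative token ids only: 151645 is already a user stop word,
# so only 151643 is appended as a new entry.
print(merge_generation_config_stops([[151645]], [151645, 151643]))
# -> [[151645], [151643]]
print(merge_generation_config_stops(None, [151645, 151643]))
# -> [[151645, 151643]]
```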