Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-15 23:44:02 +08:00.
[None][fix] modify qwen3-next sampling stop_tokens (#9331)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
This commit is contained in:
parent
11a0b276fb
commit
0582e54b61
@ -395,6 +395,25 @@ class SamplingParams:
|
||||
strs = [self.stop] if isinstance(self.stop, str) else self.stop
|
||||
self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]
|
||||
|
||||
# Merge the eos_token_id list from generation_config into the stop-word list (currently only for qwen3-next).
|
||||
if (
|
||||
hf_model_config is not None
|
||||
and hf_model_config.model_type == "qwen3_next"
|
||||
and generation_config is not None
|
||||
and isinstance(generation_config.eos_token_id, List)
|
||||
and all(isinstance(i, int) for i in generation_config.eos_token_id)
|
||||
):
|
||||
if self._stop_word_ids:
|
||||
all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
|
||||
from_generation_stop_tokens = [
|
||||
i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
|
||||
]
|
||||
|
||||
if from_generation_stop_tokens:
|
||||
self._stop_word_ids.append(from_generation_stop_tokens)
|
||||
else:
|
||||
self._stop_word_ids = [generation_config.eos_token_id]
|
||||
|
||||
return self
|
||||
|
||||
def _get_bad_words(self) -> List[List[int]]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user