[None][feat] add the eos tokens in generation config to stop words in the sampler (#10389)

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
This commit is contained in:
JadoTu 2026-01-06 09:24:03 +08:00 committed by GitHub
parent 8a04c05079
commit 82aaf98070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 26 deletions

View File

@ -541,7 +541,8 @@ class BaseWorker(GenerationExecutor):
guided_decoding_params=request.sampling_params.
_get_guided_decoding_params(),
bad_words=request.sampling_params._get_bad_words(),
stop_words=request.sampling_params._get_stop_words(),
stop_words=[] if request.sampling_params.ignore_eos else
request.sampling_params._get_stop_words(),
embedding_bias=request.sampling_params.embedding_bias,
lora_config=lora_config,
prompt_tuning_config=prompt_tuning_config,

View File

@ -373,14 +373,6 @@ class SamplingParams:
if self.end_id is None:
self.end_id = tokenizer.eos_token_id
self.pad_id = tokenizer.pad_token_id
# kimi_k2 model uses the eos_token_id in generation config
if (
hf_model_config is not None
and hf_model_config.model_type == "kimi_k2"
and generation_config is not None
and isinstance(generation_config.eos_token_id, int)
):
self.end_id = generation_config.eos_token_id
if self.pad_id is None:
self.pad_id = self.end_id
@ -400,24 +392,20 @@ class SamplingParams:
strs = [self.stop] if isinstance(self.stop, str) else self.stop
self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]
# add generation_config to stop word list, only in qwen3-next now
if (
hf_model_config is not None
and hf_model_config.model_type == "qwen3_next"
and generation_config is not None
and isinstance(generation_config.eos_token_id, List)
and all(isinstance(i, int) for i in generation_config.eos_token_id)
):
if self._stop_word_ids:
all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
from_generation_stop_tokens = [
i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
]
# Add eos_token_id in generation_config to stop_token_ids
# The eos_token_id in generation_config are really mean to stop the text generation.
if generation_config is not None and generation_config.eos_token_id is not None:
if isinstance(generation_config.eos_token_id, int):
generation_config.eos_token_id = [generation_config.eos_token_id]
# else is always List[int]
if from_generation_stop_tokens:
self._stop_word_ids.append(from_generation_stop_tokens)
else:
self._stop_word_ids = [generation_config.eos_token_id]
if not self.stop_token_ids:
self.stop_token_ids = []
for stop_token in generation_config.eos_token_id:
if stop_token != self.end_id and stop_token not in self.stop_token_ids:
self.stop_token_ids.append(stop_token)
if not self.stop_token_ids:
self.stop_token_ids = None
return self