Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][feat] add the eos tokens in generation config to stop words in the sampler (#10389)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
Commit 82aaf98070 (parent 8a04c05079)
@@ -541,7 +541,8 @@ class BaseWorker(GenerationExecutor):
                 guided_decoding_params=request.sampling_params.
                 _get_guided_decoding_params(),
                 bad_words=request.sampling_params._get_bad_words(),
-                stop_words=request.sampling_params._get_stop_words(),
+                stop_words=[] if request.sampling_params.ignore_eos else
+                request.sampling_params._get_stop_words(),
                 embedding_bias=request.sampling_params.embedding_bias,
                 lora_config=lora_config,
                 prompt_tuning_config=prompt_tuning_config,
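A minimal sketch of the semantics this hunk introduces: when ignore_eos is set on the request's sampling params, the sampler receives an empty stop-word list, so decoding only halts at the token budget. The SamplingParamsSketch class and stop_words_for_request helper below are hypothetical stand-ins for illustration, not TensorRT-LLM APIs.

# Hypothetical stand-in mirroring only the two fields involved; not the
# actual executor request path.
from dataclasses import dataclass, field
from typing import List

@dataclass
class SamplingParamsSketch:
    ignore_eos: bool = False
    stop_word_ids: List[List[int]] = field(default_factory=list)

    def _get_stop_words(self) -> List[List[int]]:
        return self.stop_word_ids

def stop_words_for_request(params: SamplingParamsSketch) -> List[List[int]]:
    # Mirrors the changed line: suppress every stop sequence when the
    # caller asked to ignore EOS, otherwise forward the configured ones.
    return [] if params.ignore_eos else params._get_stop_words()

assert stop_words_for_request(
    SamplingParamsSketch(ignore_eos=True, stop_word_ids=[[2]])) == []
assert stop_words_for_request(
    SamplingParamsSketch(stop_word_ids=[[2]])) == [[2]]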
@@ -373,14 +373,6 @@ class SamplingParams:
         if self.end_id is None:
             self.end_id = tokenizer.eos_token_id
             self.pad_id = tokenizer.pad_token_id
-            # kimi_k2 model uses the eos_token_id in generation config
-            if (
-                hf_model_config is not None
-                and hf_model_config.model_type == "kimi_k2"
-                and generation_config is not None
-                and isinstance(generation_config.eos_token_id, int)
-            ):
-                self.end_id = generation_config.eos_token_id

         if self.pad_id is None:
             self.pad_id = self.end_id
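After the kimi_k2 special case is deleted, the fallback chain that remains is: end_id defaults to the tokenizer's eos_token_id (taking pad_token_id along with it), and pad_id defaults to end_id. A hedged sketch of that chain, with a hypothetical resolve_ids helper and illustrative token ids:

from typing import Optional, Tuple

def resolve_ids(end_id: Optional[int], pad_id: Optional[int],
                tokenizer_eos: Optional[int],
                tokenizer_pad: Optional[int]) -> Tuple[Optional[int], Optional[int]]:
    # Unset end_id falls back to the tokenizer's EOS, and pad_id is
    # taken from the tokenizer in the same branch.
    if end_id is None:
        end_id = tokenizer_eos
        pad_id = tokenizer_pad
    # A still-unset pad_id falls back to end_id.
    if pad_id is None:
        pad_id = end_id
    return end_id, pad_id

assert resolve_ids(None, None, 2, None) == (2, 2)  # EOS used for both
assert resolve_ids(7, None, 2, 0) == (7, 7)        # explicit end_id wins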
@@ -400,24 +392,20 @@ class SamplingParams:
             strs = [self.stop] if isinstance(self.stop, str) else self.stop
             self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]

-            # add generation_config to stop word list, only in qwen3-next now
-            if (
-                hf_model_config is not None
-                and hf_model_config.model_type == "qwen3_next"
-                and generation_config is not None
-                and isinstance(generation_config.eos_token_id, List)
-                and all(isinstance(i, int) for i in generation_config.eos_token_id)
-            ):
-                if self._stop_word_ids:
-                    all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
-                    from_generation_stop_tokens = [
-                        i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
-                    ]
-
-                    if from_generation_stop_tokens:
-                        self._stop_word_ids.append(from_generation_stop_tokens)
-                else:
-                    self._stop_word_ids = [generation_config.eos_token_id]
+        # Add eos_token_id in generation_config to stop_token_ids.
+        # The eos_token_ids in generation_config are really meant to stop text generation.
+        if generation_config is not None and generation_config.eos_token_id is not None:
+            if isinstance(generation_config.eos_token_id, int):
+                generation_config.eos_token_id = [generation_config.eos_token_id]
+            # else it is always List[int]
+
+            if not self.stop_token_ids:
+                self.stop_token_ids = []
+            for stop_token in generation_config.eos_token_id:
+                if stop_token != self.end_id and stop_token not in self.stop_token_ids:
+                    self.stop_token_ids.append(stop_token)
+            if not self.stop_token_ids:
+                self.stop_token_ids = None

         return self
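For clarity, a self-contained sketch of the merge the added branch performs: eos ids from the generation config are normalized to a list and folded into stop_token_ids, skipping the primary end_id and any duplicates, and an empty result collapses back to None. The merge_eos_into_stop_tokens helper is hypothetical, not the SamplingParams method itself.

from typing import List, Optional, Union

def merge_eos_into_stop_tokens(
    eos_token_id: Optional[Union[int, List[int]]],
    end_id: Optional[int],
    stop_token_ids: Optional[List[int]],
) -> Optional[List[int]]:
    if eos_token_id is None:
        return stop_token_ids
    # An int eos_token_id is normalized to a single-element list,
    # matching the added branch above.
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    merged = list(stop_token_ids or [])
    for tok in eos_token_id:
        # Skip the primary end_id (already a stop condition) and duplicates.
        if tok != end_id and tok not in merged:
            merged.append(tok)
    return merged or None

# Extra eos ids are appended once each; end_id itself is excluded.
assert merge_eos_into_stop_tokens([2, 7, 7], end_id=2, stop_token_ids=None) == [7]
assert merge_eos_into_stop_tokens(2, end_id=2, stop_token_ids=None) is None

One design consequence worth noting: because the merge is no longer gated on model_type, any model whose generation config declares extra eos ids (not just qwen3_next or kimi_k2) now gets them as stop tokens, and the ignore_eos path in BaseWorker can still suppress them wholesale.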