Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[None][feat] add the eos tokens in generation config to stop words in the sampler (#10389)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
This commit is contained in:
parent 8a04c05079
commit 82aaf98070
@@ -541,7 +541,8 @@ class BaseWorker(GenerationExecutor):
                 guided_decoding_params=request.sampling_params.
                 _get_guided_decoding_params(),
                 bad_words=request.sampling_params._get_bad_words(),
-                stop_words=request.sampling_params._get_stop_words(),
+                stop_words=[] if request.sampling_params.ignore_eos else
+                request.sampling_params._get_stop_words(),
                 embedding_bias=request.sampling_params.embedding_bias,
                 lora_config=lora_config,
                 prompt_tuning_config=prompt_tuning_config,
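The new guard above is what lets ignore_eos take effect in the sampler: when it is set, the request is built with an empty stop-word list instead of the EOS-derived stop words. A minimal sketch of the same pattern, with a hypothetical Params class standing in for SamplingParams:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Params:
    # Hypothetical stand-in for SamplingParams, reduced to the fields this sketch needs.
    ignore_eos: bool = False
    _stop_word_ids: List[List[int]] = field(default_factory=list)

    def _get_stop_words(self) -> List[List[int]]:
        return list(self._stop_word_ids)

def stop_words_for_request(params: Params) -> List[List[int]]:
    # Mirrors the diff: send no stop words when ignore_eos is set, so
    # generation runs to max_tokens even if EOS-like tokens are produced.
    return [] if params.ignore_eos else params._get_stop_words()

print(stop_words_for_request(Params(ignore_eos=True, _stop_word_ids=[[151645]])))   # []
print(stop_words_for_request(Params(ignore_eos=False, _stop_word_ids=[[151645]])))  # [[151645]]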
@@ -373,14 +373,6 @@ class SamplingParams:
         if self.end_id is None:
             self.end_id = tokenizer.eos_token_id
             self.pad_id = tokenizer.pad_token_id
-            # kimi_k2 model uses the eos_token_id in generation config
-            if (
-                hf_model_config is not None
-                and hf_model_config.model_type == "kimi_k2"
-                and generation_config is not None
-                and isinstance(generation_config.eos_token_id, int)
-            ):
-                self.end_id = generation_config.eos_token_id

             if self.pad_id is None:
                 self.pad_id = self.end_id
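The deleted block was a one-model special case for reading eos_token_id out of the model's generation config; the generic merge added in the next hunk now covers it for every model. A short sketch of where that value comes from, assuming Hugging Face transformers (the model id is illustrative):

from transformers import GenerationConfig

# generation_config.json ships next to the model weights; eos_token_id
# can be a single int or a List[int], depending on the model.
gen_cfg = GenerationConfig.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
print(type(gen_cfg.eos_token_id), gen_cfg.eos_token_id)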
@@ -400,24 +392,20 @@ class SamplingParams:
         strs = [self.stop] if isinstance(self.stop, str) else self.stop
         self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]

-        # add generation_config to stop word list, only in qwen3-next now
-        if (
-            hf_model_config is not None
-            and hf_model_config.model_type == "qwen3_next"
-            and generation_config is not None
-            and isinstance(generation_config.eos_token_id, List)
-            and all(isinstance(i, int) for i in generation_config.eos_token_id)
-        ):
-            if self._stop_word_ids:
-                all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
-                from_generation_stop_tokens = [
-                    i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
-                ]
-
-                if from_generation_stop_tokens:
-                    self._stop_word_ids.append(from_generation_stop_tokens)
-            else:
-                self._stop_word_ids = [generation_config.eos_token_id]
+        # Add eos_token_id in generation_config to stop_token_ids:
+        # the eos_token_id values in generation_config are really meant to stop text generation.
+        if generation_config is not None and generation_config.eos_token_id is not None:
+            if isinstance(generation_config.eos_token_id, int):
+                generation_config.eos_token_id = [generation_config.eos_token_id]
+            # else it is always List[int]
+            if not self.stop_token_ids:
+                self.stop_token_ids = []
+            for stop_token in generation_config.eos_token_id:
+                if stop_token != self.end_id and stop_token not in self.stop_token_ids:
+                    self.stop_token_ids.append(stop_token)
+            if not self.stop_token_ids:
+                self.stop_token_ids = None

         return self
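The replacement logic is self-contained enough to exercise on its own. Below is a minimal sketch (the free-standing function name and signature are mine; the body follows the hunk): normalize a scalar eos_token_id to a list, append each id to stop_token_ids unless it equals end_id or is already present, and restore the None convention if nothing remains.

from typing import List, Optional, Union

def merge_eos_into_stop_token_ids(
    stop_token_ids: Optional[List[int]],
    end_id: Optional[int],
    eos_token_id: Union[int, List[int], None],
) -> Optional[List[int]]:
    # No generation-config eos ids: leave stop_token_ids untouched.
    if eos_token_id is None:
        return stop_token_ids
    # Normalize the scalar form; otherwise it is already List[int].
    eos_ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
    merged = list(stop_token_ids) if stop_token_ids else []
    for tok in eos_ids:
        # end_id already stops generation, so only extra eos ids are added.
        if tok != end_id and tok not in merged:
            merged.append(tok)
    # An empty list would mean "no stop tokens"; keep the None convention.
    return merged if merged else None

print(merge_eos_into_stop_token_ids(None, 2, [2, 151645]))      # [151645]
print(merge_eos_into_stop_token_ids([151645], 2, [2, 151645]))  # [151645]
print(merge_eos_into_stop_token_ids(None, 2, 2))                # None

Skipping end_id avoids duplicating the token the sampler already treats as the terminator, and collapsing an empty list back to None keeps downstream "is stop_token_ids set" checks behaving as before.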