# tensorrt_llm/llmapi/tokenizer.py
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)


class TokenizerBase(PreTrainedTokenizerBase):
    ''' This is a protocol class for the tokenizer. Users can implement their own tokenizer by inheriting from this class. '''


class TransformersTokenizer(TokenizerBase):
    ''' A wrapper over the Hugging Face Transformers tokenizer.
    This is the default tokenizer for LLM. '''

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)

    def __call__(self, text: str, *args, **kwargs) -> Any:
        return self.tokenizer(text, *args, **kwargs)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_token_id

    @property
    def name_or_path(self) -> str:
        return self.tokenizer.name_or_path

    def encode(self, text: str, *args, **kwargs) -> List[int]:
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids: List[int], *args, **kwargs) -> str:
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
        return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)

    def apply_chat_template(
            self, conversation: Union[List[Dict[str, str]],
                                      List[List[Dict[str, str]]]], *args,
            **kwargs) -> Union[str, List[int], List[str], List[List[int]]]:
        return self.tokenizer.apply_chat_template(conversation, *args,
                                                  **kwargs)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tokenizer})"

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
                                                  **kwargs)
        return cls(tokenizer)

    def save_pretrained(self, pretrained_model_dir: str, **kwargs):
        self.tokenizer.save_pretrained(pretrained_model_dir, **kwargs)

    def clean_up_tokenization(self, out_string: str) -> str:
        return self.tokenizer.clean_up_tokenization(out_string)

    @property
    def clean_up_tokenization_spaces(self):
        return self.tokenizer.clean_up_tokenization_spaces

    @property
    def is_fast(self) -> bool:
        return self.tokenizer.is_fast

    def get_added_vocab(self) -> Dict[str, int]:
        # Assumed to be O(1) complexity
        return self.tokenizer.get_added_vocab()

    def convert_ids_to_tokens(
            self,
            ids: Union[int, List[int]],
            skip_special_tokens: bool = False) -> Union[str, List[str]]:
        return self.tokenizer.convert_ids_to_tokens(
            ids, skip_special_tokens=skip_special_tokens)

    def convert_tokens_to_string(
            self,
            tokens: List[str],
            skip_special_tokens: bool = False,
            spaces_between_special_tokens: bool = True) -> str:
        # Adapted from
        # https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L172
        if self.is_fast or not self.get_added_vocab():
            return self.tokenizer.convert_tokens_to_string(tokens)

        sub_texts: List[str] = []
        current_sub_text: List[str] = []
        for token in tokens:
            if skip_special_tokens and token in self._all_special_tokens_set:
                continue
            if token in self.get_added_vocab():
                if current_sub_text:
                    sub_text = self.tokenizer.convert_tokens_to_string(
                        current_sub_text)
                    sub_texts.append(sub_text)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_text = self.tokenizer.convert_tokens_to_string(current_sub_text)
            sub_texts.append(sub_text)
        if spaces_between_special_tokens:
            return " ".join(sub_texts)
        else:
            return "".join(sub_texts)

    def decode_incrementally(
            self,
            token_ids: List[int],
            prev_text: Optional[str] = None,
            states: Optional[dict] = None,
            *,
            flush: bool = False,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: Optional[bool] = None,
            spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
        """Incremental detokenization, typically used for streaming generation.

        Args:
            token_ids (List[int]): The incremental token ids.
            prev_text (str): The previously decoded text. None if this is the first iteration.
            states (dict): A dict that saves the previous states for incremental detokenization. None if this is the first iteration.
            flush (bool): Force flushing the pending tokens to the decoded text.
            skip_special_tokens (bool): Whether to remove special tokens in the decoding.
            clean_up_tokenization_spaces (bool): Whether to clean up tokenization spaces.
            spaces_between_special_tokens (bool): Whether to add spaces between special tokens.

        Returns:
            text, states (Tuple[str, dict]): text is the current decoded text and states is the current incremental
                detokenization states. They should be passed to the next incremental detokenization iteration, if any.
        """
        # Adapted from
        # https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L238
        if prev_text is None:
            prev_text = ""
        if states is None:
            states = {}

        last_new_tokens = states.pop('last_new_tokens', [])
        pending_tokens = states.pop('pending_tokens', [])
        if len(last_new_tokens) > 0:
            last_new_text = self.convert_tokens_to_string(
                last_new_tokens,
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens)
        else:
            last_new_text = ""

        new_tokens = self.convert_ids_to_tokens(
            token_ids, skip_special_tokens=skip_special_tokens)
        pending_tokens.extend(new_tokens)
        curr_new_text = self.convert_tokens_to_string(
            last_new_tokens + pending_tokens,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens)

        # Hold back the pending tokens if they do not yet produce any new text,
        # or if the text ends with a replacement character, which indicates an
        # incomplete multi-byte sequence.
        if not flush and (len(curr_new_text.rstrip()) <= len(
                last_new_text.rstrip()) or curr_new_text.endswith("�")):
            return prev_text, {
                'last_new_tokens': last_new_tokens,
                'pending_tokens': pending_tokens
            }

        # Strip the prefix that was already emitted as last_new_text.
        curr_new_text = curr_new_text[len(last_new_text):]
        if clean_up_tokenization_spaces is None:
            clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
        if clean_up_tokenization_spaces:
            curr_new_text = self.clean_up_tokenization(curr_new_text)
        return prev_text + curr_new_text, {'last_new_tokens': pending_tokens}
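

# A minimal usage sketch of TransformersTokenizer.decode_incrementally in a
# streaming loop: pass the text and states returned by the previous call into
# the next one, and flush on the final chunk. The helper name and the chunked
# input below are illustrative only, not part of this module's API.
def _stream_decode_example(tokenizer: TransformersTokenizer,
                           token_id_chunks: List[List[int]]) -> str:
    text: Optional[str] = None
    states: Optional[dict] = None
    for i, chunk in enumerate(token_id_chunks):
        is_last = i == len(token_id_chunks) - 1
        text, states = tokenizer.decode_incrementally(chunk,
                                                      prev_text=text,
                                                      states=states,
                                                      flush=is_last)
    return text or ""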


def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
                                          TokenizerBase]] = None,
                      **kwargs) -> Optional[TokenizerBase]:
    if obj is None:
        return None
    elif isinstance(obj, (str, Path)):
        default_kwargs = {
            'legacy': False,
            'padding_side': 'left',
            'truncation_side': 'left',
            'trust_remote_code': True,
            'use_fast': True,
        }
        default_kwargs.update(kwargs)
        return TransformersTokenizer.from_pretrained(obj, **default_kwargs)
    elif isinstance(obj, TokenizerBase):
        return obj
    elif isinstance(obj, PreTrainedTokenizerBase):
        return TransformersTokenizer(obj)
    else:
        raise TypeError(f"Unrecognized tokenizer {obj}")
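

# Illustrative sketch of the inputs tokenizer_factory accepts ("gpt2" is only a
# placeholder checkpoint name); all of them are normalized to a TokenizerBase.
def _tokenizer_factory_example() -> Optional[TokenizerBase]:
    assert tokenizer_factory(None) is None
    # A path or repo name is loaded via AutoTokenizer under the hood.
    tok = tokenizer_factory("gpt2")
    # A raw Hugging Face tokenizer is wrapped in TransformersTokenizer.
    tok = tokenizer_factory(AutoTokenizer.from_pretrained("gpt2"))
    # An existing TokenizerBase is returned as-is.
    return tokenizer_factory(tok)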


def _xgrammar_tokenizer_info(tokenizer):
    # Reference: https://github.com/mlc-ai/xgrammar/blob/b9a16de54e1e0eff58da14c65750414cceaf1a6f/python/xgrammar/tokenizer_info.py#L133
    if isinstance(tokenizer, TokenizerBase):
        tokenizer = tokenizer.tokenizer
    stop_token_ids = [tokenizer.eos_token_id]

    try:
        encoded_vocab = tokenizer.get_vocab()
        encoded_vocab = [
            token
            for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
        ]
    except AttributeError as e:
        msg = (
            f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer "
            "should have a get_vocab method.")
        raise ValueError(msg) from e

    if isinstance(tokenizer, PreTrainedTokenizerFast):
        backend_str = tokenizer.backend_tokenizer.to_str()
        return {
            "encoded_vocab": encoded_vocab,
            "tokenizer_str": backend_str,
            "stop_token_ids": stop_token_ids
        }
    elif ("vocab_file" in tokenizer.vocab_files_names
          and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]):
        return {
            "encoded_vocab": encoded_vocab,
            "stop_token_ids": stop_token_ids
        }
    else:
        raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")


def load_hf_tokenizer(model_dir: str,
                      trust_remote_code: bool = True,
                      use_fast: bool = True) -> Optional[TransformersTokenizer]:
    ''' Load a tokenizer from a Hugging Face model directory.

    Args:
        model_dir (str): The model directory.
        trust_remote_code (bool): Whether to trust the remote code.
        use_fast (bool): Whether to use the fast tokenizer.

    Returns:
        A TransformersTokenizer object if the tokenizer is loaded successfully, otherwise None.
    '''
    try:
        return TransformersTokenizer.from_pretrained(
            model_dir,
            legacy=False,
            padding_side='left',
            truncation_side='left',
            trust_remote_code=trust_remote_code,
            use_fast=use_fast)
    except Exception:
        return None
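

# A minimal sketch of a typical caller: try to load the tokenizer bundled with
# a checkpoint directory and fall back to a caller-provided tokenizer when
# loading fails. The helper name and fallback argument are illustrative only.
def _load_tokenizer_with_fallback(
        model_dir: str,
        fallback: Optional[TokenizerBase] = None) -> Optional[TokenizerBase]:
    tokenizer = load_hf_tokenizer(model_dir)
    return tokenizer if tokenizer is not None else fallback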