TensorRT-LLMs/tensorrt_llm/hlapi/tokenizer.py
2024-03-19 17:36:42 +08:00

66 lines
2.1 KiB
Python

from pathlib import Path
from typing import Any, List, Union
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
TokenIdsTy = List[int]
class TokenizerBase(PreTrainedTokenizerBase):
''' This is a protocol for the tokenizer. Users can implement their own tokenizer by inheriting this class. '''
class TransformersTokenizer(TokenizerBase):
''' A wrapper for the Transformers' tokenizer.
This is the default tokenizer for LLM. '''
@classmethod
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
**kwargs)
return TransformersTokenizer(tokenizer)
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, text: str, *args, **kwargs) -> Any:
return self.tokenizer(text, *args, **kwargs)
@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
def pad_token_id(self) -> int:
return self.tokenizer.pad_token_id
def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
return self.tokenizer.encode(text, *args, **kwargs)
def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
return self.tokenizer.decode(token_ids, *args, **kwargs)
def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)
def tokenizer_factory(
obj: Union[str, Path, TokenizerBase, PreTrainedTokenizerBase, None],
**kwargs) -> Union[TokenizerBase, PreTrainedTokenizerBase, None]:
if obj is None:
return None
if isinstance(obj, (str, Path)):
default_kwargs = {
'legacy': False,
'padding_side': 'left',
'truncation_side': 'left',
'trust_remote_code': True,
'use_fast': True,
}
default_kwargs.update(kwargs)
return AutoTokenizer.from_pretrained(obj, **kwargs)
return obj