TensorRT-LLMs/tensorrt_llm/hlapi/tokenizer.py
Kaiyu Xie d879430b04
Update TensorRT-LLM (#846)
* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-01-09 21:03:35 +08:00

72 lines
2.4 KiB
Python

from typing import Any, List
# A list of integer token ids: the return type of ``encode`` and the input
# type of ``decode``.
TokenIdsTy = List[int]


class TokenizerBase:
    ''' This is a protocol for the tokenizer. Users can implement their own
    tokenizer by inheriting this class and overriding the methods below.

    Every method on this base class raises ``NotImplementedError`` except
    ``tokenize``, which forwards to ``encode``. '''

    @property
    def eos_token_id(self) -> int:
        ''' Return the id of the end of sentence token. '''
        raise NotImplementedError()

    @property
    def pad_token_id(self) -> int:
        ''' Return the id of the padding token. '''
        raise NotImplementedError()

    def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
        ''' Encode the text to token ids. '''
        raise NotImplementedError()

    def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
        ''' Decode the token ids to text. '''
        raise NotImplementedError()

    def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
        ''' Encode the batch of texts to token ids.

        Extra positional/keyword arguments are accepted so that the base
        signature matches implementations (such as TransformersTokenizer)
        that forward options to an underlying tokenizer; without them,
        code written against this base class could not pass such options.
        '''
        raise NotImplementedError()

    def tokenize(self, text: str, *args, **kwargs) -> TokenIdsTy:
        ''' Alias of ``encode``: return the token ids for ``text``.

        All extra arguments are forwarded unchanged to ``encode``. '''
        return self.encode(text, *args, **kwargs)

    def __call__(self, text: str, *args, **kwargs) -> Any:
        ''' Encode the text; the concrete return type is left to the
        implementation (hence ``Any``). '''
        raise NotImplementedError()
class TransformersTokenizer(TokenizerBase):
    ''' A wrapper for the Transformers' tokenizer.
    This is the default tokenizer for LLM.

    Every method forwards to the wrapped ``self.tokenizer`` object. '''

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        ''' Load a tokenizer with ``transformers.AutoTokenizer`` and wrap it.

        Args:
            pretrained_model_dir: model name or directory passed to
                ``AutoTokenizer.from_pretrained``.
            **kwargs: forwarded to ``AutoTokenizer.from_pretrained``.

        Returns:
            An instance of ``cls`` wrapping the loaded tokenizer, so a
            subclass calling this constructor gets its own type back.
        '''
        # Imported inside the method, so ``transformers`` is only imported
        # when this loader is actually called.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
                                                  **kwargs)
        return cls(tokenizer)

    def __init__(self, tokenizer):
        ''' Wrap an already-constructed tokenizer object. '''
        self.tokenizer = tokenizer

    def __call__(self, text: str, *args, **kwargs) -> Any:
        ''' Call the wrapped tokenizer directly and return its result. '''
        return self.tokenizer(text, *args, **kwargs)

    @property
    def eos_token_id(self) -> int:
        ''' Return the wrapped tokenizer's ``eos_token_id``. '''
        # NOTE(review): the underlying tokenizer may report None when no EOS
        # token is configured despite the ``int`` annotation — confirm.
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        ''' Return the wrapped tokenizer's ``pad_token_id``. '''
        # NOTE(review): may be None if the underlying tokenizer has no pad
        # token configured — confirm callers handle that.
        return self.tokenizer.pad_token_id

    def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
        ''' Encode ``text`` via the wrapped tokenizer's ``encode``. '''
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
        ''' Decode ``token_ids`` via the wrapped tokenizer's ``decode``. '''
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
        ''' Encode a batch via the wrapped tokenizer's
        ``batch_encode_plus``, forwarding all extra arguments. '''
        return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)