# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
from typing import Any, List

TokenIdsTy = List[int]


class TokenizerBase:
    ''' A protocol for tokenizers. Users can implement their own tokenizer
    by inheriting from this class. '''

    @property
    def eos_token_id(self) -> int:
        ''' Return the id of the end-of-sentence token. '''
        raise NotImplementedError()

    @property
    def pad_token_id(self) -> int:
        ''' Return the id of the padding token. '''
        raise NotImplementedError()

    def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
        ''' Encode the text to token ids. '''
        raise NotImplementedError()

    def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
        ''' Decode the token ids to text. '''
        raise NotImplementedError()

    def batch_encode_plus(self, texts: List[str]) -> dict:
        ''' Encode a batch of texts to token ids. '''
        raise NotImplementedError()

    def tokenize(self, text, *args, **kwargs):
        return self.encode(text, *args, **kwargs)

    def __call__(self, text: str, *args, **kwargs) -> Any:
        ''' Encode the text to token ids. '''
        raise NotImplementedError()
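

# Example (not part of the original module): a minimal sketch of a
# user-defined tokenizer implementing the TokenizerBase protocol above.
# WhitespaceTokenizer and its fixed vocabulary are hypothetical and only
# illustrate which methods a custom tokenizer is expected to provide.
class WhitespaceTokenizer(TokenizerBase):

    def __init__(self, vocab: List[str], eos_token: str = '</s>',
                 pad_token: str = '<pad>'):
        # Assign an integer id to every vocabulary entry plus the special tokens.
        self._id_to_token = list(vocab) + [eos_token, pad_token]
        self._token_to_id = {t: i for i, t in enumerate(self._id_to_token)}
        self._eos_token = eos_token
        self._pad_token = pad_token

    @property
    def eos_token_id(self) -> int:
        return self._token_to_id[self._eos_token]

    @property
    def pad_token_id(self) -> int:
        return self._token_to_id[self._pad_token]

    def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
        # Split on whitespace; unknown words map to the pad id.
        return [
            self._token_to_id.get(word, self.pad_token_id)
            for word in text.split()
        ]

    def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
        return ' '.join(self._id_to_token[i] for i in token_ids)

    def batch_encode_plus(self, texts: List[str]) -> dict:
        return {'input_ids': [self.encode(text) for text in texts]}

    def __call__(self, text: str, *args, **kwargs) -> Any:
        return {'input_ids': self.encode(text, *args, **kwargs)}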


class TransformersTokenizer(TokenizerBase):
    ''' A wrapper around a Hugging Face Transformers tokenizer.

    This is the default tokenizer for LLM. '''

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
                                                  **kwargs)
        return TransformersTokenizer(tokenizer)

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, text: str, *args, **kwargs) -> Any:
        return self.tokenizer(text, *args, **kwargs)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_token_id

    def encode(self, text: str, *args, **kwargs) -> TokenIdsTy:
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids: TokenIdsTy, *args, **kwargs) -> str:
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
        return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)
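

# Usage sketch (not part of the original module). Assumes the `transformers`
# package is installed and that the model name below can be resolved to a
# tokenizer; 'gpt2' is only a placeholder.
if __name__ == '__main__':
    tokenizer = TransformersTokenizer.from_pretrained('gpt2')

    token_ids = tokenizer.encode('Hello, TensorRT-LLM!')
    print(token_ids)                    # a list of ints
    print(tokenizer.decode(token_ids))  # the ids decoded back to text

    # Batch encoding returns the usual Transformers dict with 'input_ids', etc.
    batch = tokenizer.batch_encode_plus(['first prompt', 'second prompt'])
    print(batch['input_ids'])

    print(tokenizer.eos_token_id, tokenizer.pad_token_id)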