# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py

import base64
import os

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

def get_tokenizer(name: str = "multilingual",
                  num_languages: int = 99,
                  tokenizer_dir: str = None):
    # Locate the base64-encoded BPE vocabulary, either bundled next to this
    # module or inside an explicitly provided tokenizer directory.
    if tokenizer_dir is None:
        vocab_path = os.path.join(os.path.dirname(__file__),
                                  f"assets/{name}.tiktoken")
    else:
        vocab_path = os.path.join(tokenizer_dir, f"{name}.tiktoken")
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in
            (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    # Whisper-style control tokens: end/start markers, one token per supported
    # language, task tokens, and timestamp tokens covering 0.00-30.00 s in
    # 0.02 s steps.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    # Special-token ids are appended after the BPE vocabulary, so they start
    # at len(ranks) and increase in the order listed above.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )

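# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hedged example of how the special tokens built above are commonly
# assembled into a Whisper-style decoder prompt. The helper name and prompt
# layout are assumptions for illustration only; a given runtime may build its
# prompts differently.
def _example_prompt_ids(enc,
                        language: str = "en",
                        task: str = "transcribe") -> list:
    """Return the ids for <|startoftranscript|><|lang|><|task|>."""
    prompt = f"<|startoftranscript|><|{language}|><|{task}|>"
    return enc.encode(prompt, allowed_special=enc.special_tokens_set)


# With the default multilingual vocabulary this should return
# [50258, 50259, 50359] for English transcription, the same ids that the
# decode calls in the __main__ block below use.
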
if __name__ == "__main__":
    enc = get_tokenizer()

    # Round-trip a prompt that mixes text with Whisper control tokens.
    mytest_str = "<|startofprev|> Nvidia<|startoftranscript|><|en|><|transcribe|>"
    encoding = enc.encode(mytest_str, allowed_special=enc.special_tokens_set)
    # Decode hand-built id sequences; special-token ids in the multilingual
    # vocabulary start right after the 50257 BPE ranks, e.g.
    # 50258 == <|startoftranscript|>.
    mystr = enc.decode([50361, 45, 43021, 50258, 50259, 50359])
    mystr2 = enc.decode([50361, 46284, 50258, 50259, 50359])
    print(encoding, mystr, mystr2)
    print(
        enc.encode("<|startoftranscript|>",
                   allowed_special=enc.special_tokens_set)[0])

    # Round-trip a plain Chinese string ("study hard") with no special tokens.
    my_zh_str = "好好学习"
    encoding = enc.encode(my_zh_str, allowed_special=enc.special_tokens_set)
    decoding = enc.decode(encoding)
    print(type(decoding))
    print(encoding, decoding)