# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[List[Tuple[str, str]]] = None,
    system: str = "You are a helpful assistant.",
    # If you want to change max_input_length, you also need to change
    # max_input_len in tensorrt_llm_july-release-v1/examples/qwen/build.py.
    max_input_length: int = 2048,
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
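    """Build a ChatML (or raw) prompt for Qwen and tokenize it.

    Returns a (raw_text, context_tokens) pair: the formatted prompt string
    and its token ids, truncated from the front to max_input_length tokens.
    History turns are prepended newest-first, so the oldest turns are the
    first to be dropped once the prompt would exceed max_window_size.
    """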
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode("\n")

        def _tokenize_str(role, content):
            # Render "<role>\n<content>" and return it as both text and ids.
            return (f"{role}\n{content}",
                    tokenizer.encode(role, allowed_special=set()) +
                    nl_tokens +
                    tokenizer.encode(content, allowed_special=set()))

        system_text, system_tokens_part = _tokenize_str("system", system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
        raw_text = ""
        context_tokens = []

        # Walk the history from the most recent turn backwards, prepending
        # each turn until adding one more would exceed max_window_size.
        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str("user", turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens

            response_text, response_tokens_part = _tokenize_str(
                "assistant", turn_response)
            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
            prev_chat = (
                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
            )

            current_context_size = (len(system_tokens) +
                                    len(next_context_tokens) +
                                    len(context_tokens))
            if current_context_size < max_window_size:
                context_tokens = next_context_tokens + context_tokens
                raw_text = prev_chat + raw_text
            else:
                break

        context_tokens = system_tokens + context_tokens
        raw_text = f"{im_start}{system_text}{im_end}" + raw_text
        # Append the current user turn and open the assistant turn so the
        # model continues from "assistant\n".
        context_tokens += (nl_tokens + im_start_tokens +
                           _tokenize_str("user", query)[1] + im_end_tokens +
                           nl_tokens + im_start_tokens +
                           tokenizer.encode("assistant") + nl_tokens)
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    # Truncate to max_input_length tokens, dropping from the front.
    return raw_text, context_tokens[-max_input_length:]


def get_stop_words_ids(chat_format, tokenizer):
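    """Return the stop-word token-id lists for the given chat format.

    For "chatml" these are the <|im_end|> and <|im_start|> ids; for "raw"
    they are the ids of "Human:" and the end-of-document token.
    """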
    if chat_format == "raw":
        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    elif chat_format == "chatml":
        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
    return stop_words_ids
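

# A minimal usage sketch, not part of the original module. It assumes the
# Qwen chat checkpoint "Qwen/Qwen-7B-Chat" on the Hugging Face Hub, whose
# custom tokenizer exposes the im_start_id / im_end_id / eod_id attributes
# used above; the checkpoint name and prompt text are illustrative only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # trust_remote_code is needed because Qwen ships its own tokenizer code.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat",
                                              trust_remote_code=True)

    raw_text, context_tokens = make_context(
        tokenizer,
        "What is the capital of France?",
        history=[("Hello!", "Hi! How can I help you today?")],
    )
    stop_words_ids = get_stop_words_ids("chatml", tokenizer)

    print(raw_text)  # the ChatML prompt string fed to the engine
    print(f"{len(context_tokens)} prompt tokens, "
          f"stop word ids: {stop_words_ids}")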