TensorRT-LLMs/tensorrt_llm/models/qwen/utils.py
Kaiyu Xie 66ef1df492
Update TensorRT-LLM (#1492)
* Update TensorRT-LLM

---------

Co-authored-by: Loki <lokravi@amazon.com>
2024-04-24 14:44:22 +08:00

117 lines
4.3 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def make_context(
    tokenizer,
    query,
    history,
    system,
    max_input_length,
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    """Build the prompt text and the prompt token ids for one chat request.

    With ``chat_format == "chatml"`` the prompt is laid out as::

        <|im_start|>system\\n{system}<|im_end|>
        [\\n<|im_start|>user\\n{q}<|im_end|>\\n<|im_start|>assistant\\n{a}<|im_end|>]*
        \\n<|im_start|>user\\n{query}<|im_end|>\\n<|im_start|>assistant\\n

    History turns are considered from the last one backwards and a turn is
    kept only while ``system + kept turns + this turn`` stays strictly below
    ``max_window_size`` tokens; the first turn that does not fit stops the
    scan. The tokens of the new ``query`` are not counted in that check.

    With ``chat_format == "raw"`` the query is encoded as-is and ``history``
    and ``system`` are ignored.

    Args:
        tokenizer: Object exposing ``im_start_id``, ``im_end_id`` and
            ``encode(text, allowed_special=...)``; ``encode`` must return a
            list, because its results are concatenated with ``+``.
        query: Text of the new user message.
        history: Sequence of ``(user_text, assistant_text)`` pairs, or
            ``None`` for no history. It is walked with ``reversed()``.
        system: Text of the system message.
        max_input_length: Maximum number of token ids to return. Excess ids
            are dropped from the front, so the end of the prompt survives.
            ``None`` disables truncation; zero or a negative value returns
            no tokens.
        max_window_size: Token budget used when selecting history turns.
        chat_format: ``"chatml"`` or ``"raw"``.

    Returns:
        ``(raw_text, context_tokens)``. ``raw_text`` is never truncated, so
        it can be longer than what ``context_tokens`` represents once
        ``max_input_length`` cuts tokens off.

    Raises:
        NotImplementedError: If ``chat_format`` is not a supported value.
    """
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode("\n")

        def _tokenize_str(role, content):
            # Role and content are encoded with no special tokens allowed;
            # the markers are added separately by id, never from the text.
            text = f"{role}\n{content}"
            tokens = (tokenizer.encode(role, allowed_special=set()) +
                      nl_tokens +
                      tokenizer.encode(content, allowed_special=set()))
            return text, tokens

        system_text, system_tokens_part = _tokenize_str("system", system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

        raw_text = ""
        context_tokens = []

        # Walk the history newest-first and prepend, so the turns that are
        # kept end up in their original order.
        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str("user", turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
            response_text, response_tokens_part = _tokenize_str(
                "assistant", turn_response)
            response_tokens = (im_start_tokens + response_tokens_part +
                               im_end_tokens)

            next_context_tokens = (nl_tokens + query_tokens + nl_tokens +
                                   response_tokens)
            prev_chat = (
                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
            )

            current_context_size = (len(system_tokens) +
                                    len(next_context_tokens) +
                                    len(context_tokens))
            if current_context_size >= max_window_size:
                # This turn does not fit; older turns are not considered.
                break
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text

        context_tokens = system_tokens + context_tokens
        raw_text = f"{im_start}{system_text}{im_end}" + raw_text

        # Append the new user message and leave the assistant turn open.
        context_tokens += (nl_tokens + im_start_tokens +
                           _tokenize_str("user", query)[1] + im_end_tokens +
                           nl_tokens + im_start_tokens +
                           tokenizer.encode("assistant") + nl_tokens)
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    # Truncate to max_input_length, dropping tokens from the front.
    if max_input_length is None:
        return raw_text, context_tokens
    if max_input_length <= 0:
        # context_tokens[-0:] is the whole list, not an empty one, so a
        # zero budget has to be handled explicitly.
        return raw_text, []
    return raw_text, context_tokens[-max_input_length:]
def get_qwen_key_list(qwen_type):
    """Return the checkpoint weight-name fragments for a Qwen model family.

    Both lists have the same length and the same slot order; the trailing
    comment on each entry names the slot it fills (attention.qkv,
    attention.dense, mlp.gate, mlp.fc, mlp.proj, input_layernorm,
    post_layernorm, vocabulary embedding, final layer norm).

    Args:
        qwen_type: ``'qwen'`` selects the first-generation names; any other
            value selects the Qwen2-style names.

    Returns:
        A new list of nine strings on every call, so callers may mutate it.
    """
    if qwen_type == 'qwen':
        return [
            "attn.c_attn",  # attention.qkv
            "attn.c_proj",  # attention.dense
            "mlp.w1",  # mlp.gate
            "mlp.w2",  # mlp.fc
            "mlp.c_proj",  # mlp.proj
            "ln_1",  # input_layernorm
            "ln_2",  # post_layernorm
            "transformer.wte",  # vocabulary embedding
            "transformer.ln_f",  # final layer norm
        ]

    # NOTE(review): "self_attn." ends with a dot, presumably because callers
    # append the q/k/v projection name to it - confirm against the callers.
    return [
        "self_attn.",  # attention.qkv
        "self_attn.o_proj",  # attention.dense
        "mlp.up_proj",  # mlp.gate
        "mlp.gate_proj",  # mlp.fc
        "mlp.down_proj",  # mlp.proj
        "input_layernorm",  # input_layernorm
        "post_attention_layernorm",  # post_layernorm
        "model.embed_tokens",  # vocabulary embedding
        "model.norm",  # final layer norm
    ]