TensorRT-LLM/tensorrt_llm/runtime/model_runner_cpp.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from pathlib import Path
from typing import List, Optional, Union

import torch

import tensorrt_llm.bindings.executor as trtllm

from .. import profiler
from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig
from ..logger import logger
from ..mapping import Mapping
from .generation import LogitsProcessor, SamplingConfig, StoppingCriteria
from .model_runner import ModelRunnerMixin
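

# Maps TensorRT-LLM binding data types to the corresponding torch dtypes.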
_bindings_dtype_to_torch_dtype_dict = {
DataType.FLOAT: torch.float,
DataType.HALF: torch.half,
DataType.INT8: torch.int8,
DataType.INT32: torch.int32,
DataType.BOOL: torch.bool,
DataType.UINT8: torch.uint8,
DataType.BF16: torch.bfloat16,
DataType.INT64: torch.int64
}


class ModelRunnerCpp(ModelRunnerMixin):
"""
An interface class that wraps Executor and provides generation methods.
"""

    def __init__(self, executor: trtllm.Executor, max_batch_size: int,
max_input_len: int, max_seq_len: int, max_beam_width: int,
model_config: ModelConfig, world_config: WorldConfig) -> None:
self.session = executor
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_seq_len = max_seq_len
self.max_beam_width = max_beam_width
self.model_config = model_config
self.mapping = Mapping(world_size=world_config.tensor_parallelism *
world_config.pipeline_parallelism,
rank=world_config.rank,
gpus_per_node=world_config.gpus_per_node,
tp_size=world_config.tensor_parallelism,
pp_size=world_config.pipeline_parallelism)
self.world_config = world_config

    @classmethod
def from_dir(
cls,
engine_dir: str,
*,
lora_dir: Optional[str] = None,
rank: int = 0,
max_batch_size: Optional[int] = None,
max_input_len: Optional[int] = None,
max_output_len: Optional[int] = None,
max_beam_width: Optional[int] = None,
max_attention_window_size: Optional[int] = None,
sink_token_length: Optional[int] = None,
kv_cache_free_gpu_memory_fraction: Optional[float] = None,
medusa_choices: list[list[int]] | None = None,
debug_mode: bool = False,
lora_ckpt_source: str = "hf",
gpu_weights_percent: float = 1,
max_tokens_in_paged_kv_cache: int | None = None,
kv_cache_enable_block_reuse: bool = False,
enable_chunked_context: bool = False,
) -> 'ModelRunnerCpp':
"""
Create a ModelRunnerCpp instance from an engine directory.
Args:
engine_dir (str):
The directory that contains the serialized engine files and config files.
lora_dir (str):
The directory that contains LoRA weights.
rank (int):
The runtime rank id.
max_batch_size (int):
The runtime batch size limit. If max_batch_size is not None, it should not
be larger than the engine's max_batch_size; otherwise, the engine's max_batch_size
will be used.
max_input_len (int):
The runtime input length limit. If max_input_len is not None, it should not
be larger than the engine's max_input_len; otherwise, the engine's max_input_len
will be used.
max_output_len (int):
The runtime output length limit. If max_output_len is not None, it should not
be larger than the engine's max_output_len; otherwise, the engine's max_output_len
will be used.
max_beam_width (int):
The runtime beam width limit. If max_beam_width is not None, it should not
be larger than the engine's max_beam_width; otherwise, the engine's max_beam_width
will be used.
max_attention_window_size (int):
The attention window size that controls the sliding window attention / cyclic kv cache behavior.
sink_token_length (int) :
The sink token length, default=0.
kv_cache_free_gpu_memory_fraction (float) :
Free GPU memory fraction that KV cache used.
debug_mode (bool):
Whether or not to turn on the debug mode.
medusa_choices (List[List[int]]):
Medusa choices to use when in Medusa decoding.
lora_ckpt_source (str):
Source of checkpoint. Should be one of ['hf', 'nemo'].
max_tokens_in_paged_kv_cache (int):
Maximum amount of tokens configured in kv cache.
kv_cache_enable_block_reuse (bool):
Enables block reuse in kv cache.
enable_chunked_context (bool):
Enables chunked context.
Returns:
ModelRunnerCpp: An instance of ModelRunnerCpp.
"""
config_path = Path(engine_dir) / "config.json"
json_config = GptJsonConfig.parse_file(config_path)
model_config = json_config.model_config
        # Note: The parallel configuration is fetched automatically by the trtllm.Executor
        # constructor by inspecting the json file. These values are only kept here so that the
        # vocab_size_padded and num_layers properties can be computed.
tp_size = json_config.tensor_parallelism
pp_size = json_config.pipeline_parallelism
gpus_per_node = json_config.gpus_per_node
world_config = WorldConfig.mpi(tensor_parallelism=tp_size,
pipeline_parallelism=pp_size,
gpus_per_node=gpus_per_node)
assert rank == world_config.rank
profiler.start('load tensorrt_llm engine')
kv_cache_config = trtllm.KvCacheConfig(
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
max_attention_window=max_attention_window_size,
sink_token_length=sink_token_length,
max_tokens=max_tokens_in_paged_kv_cache,
enable_block_reuse=kv_cache_enable_block_reuse)
decoding_config = trtllm.DecodingConfig()
if medusa_choices is not None:
decoding_config.medusa_choices = medusa_choices
executor = trtllm.Executor(
engine_dir, trtllm.ModelType.DECODER_ONLY,
trtllm.ExecutorConfig(max_beam_width=max_beam_width,
kv_cache_config=kv_cache_config,
decoding_config=decoding_config))
profiler.stop('load tensorrt_llm engine')
loading_time = profiler.elapsed_time_in_sec("load tensorrt_llm engine")
logger.info(f'Load engine takes: {loading_time} sec')
return cls(executor,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_seq_len=max_input_len + max_output_len,
max_beam_width=max_beam_width,
model_config=model_config,
world_config=world_config)

    def _check_inputs(self, batch_input_ids: List[List[int]],
sampling_config: trtllm.SamplingConfig, max_new_tokens):
batch_size = len(batch_input_ids)
if batch_size > self.max_batch_size:
raise RuntimeError(
f"Input batch size ({batch_size}) exceeds the engine or specified limit ({self.max_batch_size})"
)
input_lengths = [len(x) for x in batch_input_ids]
max_length = max(input_lengths)
if max_length > self.max_input_len:
raise RuntimeError(
f"Maximum input length ({max_length}) exceeds the engine or specified limit ({self.max_input_len})"
)
if max_length + max_new_tokens > self.max_seq_len:
raise RuntimeError(
f"Maximum input length ({max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})"
)
if sampling_config.beam_width > self.max_beam_width:
raise RuntimeError(
f"Num beams ({sampling_config.beam_width}) exceeds the engine or specified limit ({self.max_beam_width})"
)

    @property
    def dtype(self) -> torch.dtype:
        bindings_dtype = self.model_config.data_type
        return _bindings_dtype_to_torch_dtype_dict[bindings_dtype]

    @property
    def vocab_size(self) -> int:
        return self.model_config.vocab_size

    @property
    def vocab_size_padded(self) -> int:
        return self.model_config.vocab_size_padded(self.world_config.size)

    @property
    def hidden_size(self) -> int:
        return self.model_config.hidden_size

    @property
    def num_heads(self) -> int:
        return self.model_config.num_heads

    @property
    def num_layers(self) -> int:
        return self.model_config.num_layers(
            self.world_config.pipeline_parallelism)

    @property
    def max_sequence_length(self) -> int:
        return self.max_seq_len

    @property
    def remove_input_padding(self) -> bool:
        return self.model_config.use_packed_input

    @property
    def max_prompt_embedding_table_size(self) -> int:
        return self.model_config.max_prompt_embedding_table_size

    @property
    def gather_context_logits(self) -> bool:
        return self.model_config.compute_context_logits

    @property
    def gather_generation_logits(self) -> bool:
        return self.model_config.compute_generation_logits

    def generate(self,
batch_input_ids: List[torch.Tensor],
*,
sampling_config: Optional[SamplingConfig] = None,
lora_uids: Optional[list] = None,
streaming: bool = False,
stopping_criteria: Optional[StoppingCriteria] = None,
logits_processor: Optional[LogitsProcessor] = None,
max_new_tokens: int = 1,
end_id: int | None = None,
pad_id: int | None = None,
bad_words_list: list[list[int]] | None = None,
stop_words_list: list[list[int]] | None = None,
return_dict: bool = False,
output_sequence_lengths: bool = False,
output_log_probs: bool = False,
output_cum_log_probs: bool = False,
prompt_table: Optional[Union[str, torch.Tensor]] = None,
prompt_tasks: Optional[str] = None,
**kwargs) -> Union[torch.Tensor, dict]:
"""
Generates sequences of token ids.
The generation-controlling parameters are set in the sampling_config; it will be set to a default one if not passed.
You can override any sampling_config's attributes by passing corresponding parameters.
Args:
batch_input_ids (List[torch.Tensor]):
A list of input id tensors. Each tensor is of shape (sequence_length, ).
sampling_config (SamplingConfig):
The sampling configuration to be used as base parametrization for the generation call.
The passed **kwargs matching the sampling_config's attributes will override them.
If the sampling_config is not provided, a default will be used.
prompt_table (str or torch.Tensor):
The file path of prompt table (.npy format, exported by nemo_prompt_convert.py) or the prompt table itself.
prompt_tasks (str):
The prompt tuning task ids for the input batch, in format of comma-separated list (e.g., 0,3,1,0).
lora_uids (list):
The uids of LoRA weights for the input batch. Use -1 to disable the LoRA module.
streaming (bool):
Whether or not to use streaming mode for generation.
stopping_criteria (StoppingCriteria):
Custom stopping criteria.
logits_processor (LogitsProcessor):
Custom logits processors.
kwargs (Dict[str, Any]:
Ad hoc parametrization of sampling_config.
The passed **kwargs matching the sampling_config's attributes will override them.
Returns:
torch.Tensor or dict:
If return_dict=False, the method returns generated output_ids.
If return_dict=True, the method returns a dict of output_ids,
sequence_lengths (if sampling_config.output_sequence_lengths=True),
context_logits and generation_logits (if self.gather_context_logits=True and
self.gather_generation_logits=True, respectively).
"""
# TODO: Check if these can be supported now and support them
if lora_uids is not None:
raise RuntimeError("LoRA is not supported in C++ session.")
if stopping_criteria is not None:
raise RuntimeError(
"Stopping criteria is not supported in C++ session.")
if logits_processor is not None:
raise RuntimeError(
"Logits processor is not supported in C++ session.")
# If we are in a multi-gpu scenario, only rank 0 continues
if not self.session.can_enqueue_requests():
return []
# Convert tensor input to plain lists
batch_input_ids_list = [a.tolist() for a in batch_input_ids]
if sampling_config is None:
            # Convert from old API of SamplingConfig
            # Note: Due to a Python 3.10 bug, inspect cannot currently be used on the binding
            # class, so the accepted parameters are listed explicitly.
accepted_parameters = [
"num_beams", "top_k", "top_p", "top_p_min", "top_p_reset_ids",
"top_p_decay", "random_seed", "temperature", "min_length",
"beam_search_diversity_rate", "repetition_penalty",
"presence_penalty", "frequency_penalty", "length_penalty",
"early_stopping"
]
rename_params = {"num_beams": "beam_width"}
sampling_params = {
k: v
for k, v in kwargs.items() if k in accepted_parameters
}
for k, v in rename_params.items():
if k in sampling_params:
sampling_params[v] = sampling_params.pop(k)
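            # In the old API, top_p=0.0 means top-p sampling is disabled; the executor
            # config expects None in that case.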
if "top_p" in sampling_params and sampling_params["top_p"] == 0.0:
sampling_params["top_p"] = None
sampling_config = trtllm.SamplingConfig(**sampling_params)
else:
sampling_config = copy.deepcopy(sampling_config)
self._check_inputs(batch_input_ids_list, sampling_config,
max_new_tokens)
output_config = trtllm.OutputConfig(
return_context_logits=self.gather_context_logits,
return_generation_logits=self.gather_generation_logits,
return_log_probs=output_log_probs,
)
prompt_tuning_configs = self._prepare_ptuning_executor(
batch_input_ids_list, prompt_table, prompt_tasks)
stop_words_list = self._prepare_words_list(stop_words_list,
len(batch_input_ids_list))
bad_words_list = self._prepare_words_list(bad_words_list,
len(batch_input_ids_list))
requests = [
trtllm.Request(input_token_ids=input_ids,
max_new_tokens=max_new_tokens,
pad_id=pad_id,
end_id=end_id,
stop_words=stop_words,
bad_words=bad_words,
sampling_config=sampling_config,
streaming=streaming,
output_config=output_config,
prompt_tuning_config=prompt_tuning_config)
for input_ids, stop_words, bad_words, prompt_tuning_config in zip(
batch_input_ids_list, stop_words_list, bad_words_list,
prompt_tuning_configs)
]
request_ids = self.session.enqueue_requests(requests)
if not streaming:
return self._initialize_and_fill_output(request_ids, end_id,
return_dict,
output_sequence_lengths,
output_log_probs,
output_cum_log_probs,
batch_input_ids, streaming)
else:
return self._stream(request_ids, end_id, return_dict,
output_sequence_lengths, output_log_probs,
output_cum_log_probs, batch_input_ids,
streaming, batch_input_ids_list)

    def _prepare_words_list(self, words_list: torch.Tensor, batch_size: int):
if words_list is None:
return [None] * batch_size
words_list = words_list.tolist()
assert len(words_list) == batch_size
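        # Each entry holds a token row and an offsets row; trailing zeros in the token
        # row are padding, so drop them together with the matching offset entries.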
for words in words_list:
while len(words[0]) > 0 and words[0][-1] == 0:
words[0].pop()
words[1].pop()
return words_list

    def _prepare_ptuning_executor(self, batch_input_ids_list, prompt_table,
prompt_tasks):
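        # One entry per request; entries remain None when no prompt table is supplied.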
prompt_tuning_configs = len(batch_input_ids_list) * [None]
if prompt_table is not None:
prompt_table_data = self._prepare_embedding_table(
prompt_table).cuda()
if prompt_tasks is not None:
task_indices = [int(t) for t in prompt_tasks.split(',')]
assert len(task_indices) == len(batch_input_ids_list), \
f"Number of supplied tasks ({len(task_indices)}) must match input batch size ({len(batch_input_ids_list)})"
prompt_tuning_configs = [
trtllm.PromptTuningConfig(
embedding_table=prompt_table_data[task_indices[i]])
for i in range(len(batch_input_ids_list))
]
else:
prompt_tuning_configs = [
trtllm.PromptTuningConfig(
embedding_table=prompt_table_data[0])
for _ in range(len(batch_input_ids_list))
]
return prompt_tuning_configs

    def _initialize_and_fill_output(self, request_ids, end_id, return_dict,
output_sequence_lengths, output_log_probs,
output_cum_log_probs, batch_input_ids,
streaming):
output_ids = [[] for _ in range(len(request_ids))]
for reqid_pos in range(len(request_ids)):
output_ids[reqid_pos] = [[] for _ in range(self.max_beam_width)]
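        # Block until every request has produced its responses, then flatten the
        # per-request response lists before filling the outputs.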
multi_responses = self.session.await_responses(request_ids)
responses = [
response for responses in multi_responses for response in responses
]
return self._fill_output(responses, output_ids, end_id, return_dict,
output_sequence_lengths, output_log_probs,
output_cum_log_probs, batch_input_ids,
streaming, request_ids)

    def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths,
output_log_probs, output_cum_log_probs, batch_input_ids,
streaming, batch_input_ids_list):
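        # In streaming mode each beam starts from a copy of its prompt tokens and
        # grows as responses arrive.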
output_ids = [[] for _ in range(len(request_ids))]
for reqid_pos in range(len(request_ids)):
output_ids[reqid_pos] = [
copy.deepcopy(batch_input_ids_list[reqid_pos])
for _ in range(self.max_beam_width)
]
finished_reqs = 0
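        # Keep polling until every request has reported a final response; each batch
        # of responses yields an updated (partial) output.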
while finished_reqs < len(request_ids):
responses = self.session.await_responses()
for response in responses:
if response.result.is_final:
finished_reqs += 1
yield self._fill_output(responses, output_ids, end_id, return_dict,
output_sequence_lengths, output_log_probs,
output_cum_log_probs, batch_input_ids,
streaming, request_ids)

    def _fill_output(self, responses, output_ids, end_id, return_dict,
output_sequence_lengths, output_log_probs,
output_cum_log_probs, batch_input_ids, streaming,
request_ids):
cuda_device = torch.device("cuda")
for response in responses:
if response.has_error():
raise RuntimeError(response.error_msg)
reqid_pos = request_ids.index(response.request_id)
for beam, output_tokens in enumerate(
response.result.output_token_ids):
output_ids[reqid_pos][beam] += output_tokens
sequence_lengths = []
for output in output_ids:
sequence_lengths.append([len(a) for a in output])
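        # In streaming mode output_ids is reused across yields, so work on a copy;
        # every beam is then padded with end_id up to max_seq_len so the result forms
        # a rectangular tensor.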
if streaming:
output_ids = copy.deepcopy(output_ids)
for beam in output_ids:
for output_tokens in beam:
output_tokens += (self.max_seq_len -
len(output_tokens)) * [end_id]
output_ids = torch.tensor(output_ids,
dtype=torch.int32,
device=cuda_device)
if return_dict:
outputs = {'output_ids': output_ids}
if output_sequence_lengths:
outputs['sequence_lengths'] = torch.tensor(sequence_lengths,
dtype=torch.int32,
device=cuda_device)
if self.gather_context_logits:
outputs['context_logits'] = [
a.result.context_logits.cuda() for a in responses
if a.result.context_logits is not None
]
# Pad context_logits into a rectangle
max_input_length = max(a.shape[0]
for a in outputs['context_logits'])
for i, a in enumerate(outputs['context_logits']):
pad_length = max_input_length - a.shape[0]
outputs['context_logits'][i] = torch.nn.functional.pad(
a, [0, 0, 0, pad_length])
outputs['context_logits'] = torch.stack(
outputs['context_logits'])
if self.gather_generation_logits:
outputs['generation_logits'] = [
a.result.generation_logits.cuda() for a in responses
if a.result.generation_logits is not None
]
outputs['generation_logits'] = torch.stack(
outputs['generation_logits'])
if output_log_probs:
outputs['log_probs'] = [
a.result.log_probs for a in responses
if a.result.log_probs is not None
]
# Pad log_probs into a rectangle
max_seq_len = max(
len(a) for beam_list in outputs['log_probs']
for a in beam_list)
for i, a in enumerate(outputs['log_probs']):
for j, b in enumerate(a):
pad_length = max_seq_len - len(b)
outputs['log_probs'][i][j] = b + [0.0] * pad_length
outputs['log_probs'] = torch.tensor(outputs['log_probs'],
device=cuda_device)
if output_cum_log_probs:
outputs['cum_log_probs'] = [
a.result.cum_log_probs for a in responses
if a.result.cum_log_probs is not None
]
outputs['cum_log_probs'] = torch.tensor(
outputs['cum_log_probs'], device=cuda_device)
input_lengths = torch.tensor([x.size(0) for x in batch_input_ids],
dtype=torch.int32,
device=cuda_device)
outputs = self._prepare_outputs(outputs, input_lengths)
else:
outputs = output_ids
return outputs