# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from pathlib import Path from typing import List, Optional, Union import torch import tensorrt_llm.bindings.executor as trtllm from .. import profiler from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig from ..logger import logger from ..mapping import Mapping from .generation import LogitsProcessor, SamplingConfig, StoppingCriteria from .model_runner import ModelRunnerMixin _bindings_dtype_to_torch_dtype_dict = { DataType.FLOAT: torch.float, DataType.HALF: torch.half, DataType.INT8: torch.int8, DataType.INT32: torch.int32, DataType.BOOL: torch.bool, DataType.UINT8: torch.uint8, DataType.BF16: torch.bfloat16, DataType.INT64: torch.int64 } class ModelRunnerCpp(ModelRunnerMixin): """ An interface class that wraps Executor and provides generation methods. """ def __init__(self, executor: trtllm.Executor, max_batch_size: int, max_input_len: int, max_seq_len: int, max_beam_width: int, model_config: ModelConfig, world_config: WorldConfig) -> None: self.session = executor self.max_batch_size = max_batch_size self.max_input_len = max_input_len self.max_seq_len = max_seq_len self.max_beam_width = max_beam_width self.model_config = model_config self.mapping = Mapping(world_size=world_config.tensor_parallelism * world_config.pipeline_parallelism, rank=world_config.rank, gpus_per_node=world_config.gpus_per_node, tp_size=world_config.tensor_parallelism, pp_size=world_config.pipeline_parallelism) self.world_config = world_config @classmethod def from_dir( cls, engine_dir: str, *, lora_dir: Optional[str] = None, rank: int = 0, max_batch_size: Optional[int] = None, max_input_len: Optional[int] = None, max_output_len: Optional[int] = None, max_beam_width: Optional[int] = None, max_attention_window_size: Optional[int] = None, sink_token_length: Optional[int] = None, kv_cache_free_gpu_memory_fraction: Optional[float] = None, medusa_choices: list[list[int]] | None = None, debug_mode: bool = False, lora_ckpt_source: str = "hf", gpu_weights_percent: float = 1, max_tokens_in_paged_kv_cache: int | None = None, kv_cache_enable_block_reuse: bool = False, enable_chunked_context: bool = False, is_enc_dec: bool = False, ) -> 'ModelRunnerCpp': """ Create a ModelRunnerCpp instance from an engine directory. Args: engine_dir (str): The directory that contains the serialized engine files and config files. lora_dir (str): The directory that contains LoRA weights. rank (int): The runtime rank id. max_batch_size (int): The runtime batch size limit. If max_batch_size is not None, it should not be larger than the engine's max_batch_size; otherwise, the engine's max_batch_size will be used. max_input_len (int): The runtime input length limit. If max_input_len is not None, it should not be larger than the engine's max_input_len; otherwise, the engine's max_input_len will be used. max_output_len (int): The runtime output length limit. If max_output_len is not None, it should not be larger than the engine's max_output_len; otherwise, the engine's max_output_len will be used. max_beam_width (int): The runtime beam width limit. If max_beam_width is not None, it should not be larger than the engine's max_beam_width; otherwise, the engine's max_beam_width will be used. max_attention_window_size (int): The attention window size that controls the sliding window attention / cyclic kv cache behavior. sink_token_length (int) : The sink token length, default=0. kv_cache_free_gpu_memory_fraction (float) : Free GPU memory fraction that KV cache used. debug_mode (bool): Whether or not to turn on the debug mode. medusa_choices (List[List[int]]): Medusa choices to use when in Medusa decoding. lora_ckpt_source (str): Source of checkpoint. Should be one of ['hf', 'nemo']. max_tokens_in_paged_kv_cache (int): Maximum amount of tokens configured in kv cache. kv_cache_enable_block_reuse (bool): Enables block reuse in kv cache. enable_chunked_context (bool): Enables chunked context. is_enc_dec (bool): Whether the model is encoder-decoder architecture. Returns: ModelRunnerCpp: An instance of ModelRunnerCpp. """ if is_enc_dec: encoder_config_path = Path(engine_dir) / "encoder" / "config.json" encoder_json_config = GptJsonConfig.parse_file(encoder_config_path) encoder_json_config.model_config decoder_config_path = Path(engine_dir) / "decoder" / "config.json" decoder_json_config = GptJsonConfig.parse_file(decoder_config_path) decoder_model_config = decoder_json_config.model_config tp_size = decoder_json_config.tensor_parallelism pp_size = decoder_json_config.pipeline_parallelism gpus_per_node = decoder_json_config.gpus_per_node world_config = WorldConfig.mpi(tensor_parallelism=tp_size, pipeline_parallelism=pp_size, gpus_per_node=gpus_per_node) assert rank == world_config.rank profiler.start('load tensorrt_llm engine') kv_cache_config = trtllm.KvCacheConfig( free_gpu_memory_fraction=0.45, # hardcode for now max_attention_window=max_attention_window_size, sink_token_length=sink_token_length) executor = trtllm.Executor( Path(engine_dir) / "encoder", Path(engine_dir) / "decoder", trtllm.ModelType.ENCODER_DECODER, trtllm.ExecutorConfig(max_beam_width=max_beam_width, kv_cache_config=kv_cache_config)) profiler.stop('load tensorrt_llm engine') loading_time = profiler.elapsed_time_in_sec( "load tensorrt_llm engine") logger.info(f'Load engine takes: {loading_time} sec') return cls(executor, max_batch_size=max_batch_size, max_input_len=max_input_len, max_seq_len=max_input_len + max_output_len, max_beam_width=max_beam_width, model_config=decoder_model_config, world_config=world_config) config_path = Path(engine_dir) / "config.json" json_config = GptJsonConfig.parse_file(config_path) model_config = json_config.model_config if max_batch_size is None: max_batch_size = model_config.max_batch_size if max_input_len is None: max_input_len = model_config.max_input_len if max_output_len is None: max_output_len = model_config.max_seq_len - model_config.max_input_len if max_beam_width is None: max_beam_width = model_config.max_beam_width # Note: Parallel configuration will be fetched automatically from trtllm.Executor constructor # by inspecting the json file. These lines serve the purpose of serving vocab_size_padded and # num_layers properties. tp_size = json_config.tensor_parallelism pp_size = json_config.pipeline_parallelism gpus_per_node = json_config.gpus_per_node world_config = WorldConfig.mpi(tensor_parallelism=tp_size, pipeline_parallelism=pp_size, gpus_per_node=gpus_per_node) assert rank == world_config.rank profiler.start('load tensorrt_llm engine') kv_cache_config = trtllm.KvCacheConfig( free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, max_attention_window=max_attention_window_size, sink_token_length=sink_token_length, max_tokens=max_tokens_in_paged_kv_cache, enable_block_reuse=kv_cache_enable_block_reuse) decoding_config = trtllm.DecodingConfig() if medusa_choices is not None: decoding_config.medusa_choices = medusa_choices trtllm_config = trtllm.ExecutorConfig(max_beam_width=max_beam_width, kv_cache_config=kv_cache_config, decoding_config=decoding_config) trtllm_config.enable_chunked_context = enable_chunked_context executor = trtllm.Executor(engine_dir, trtllm.ModelType.DECODER_ONLY, trtllm_config) profiler.stop('load tensorrt_llm engine') loading_time = profiler.elapsed_time_in_sec("load tensorrt_llm engine") logger.info(f'Load engine takes: {loading_time} sec') return cls(executor, max_batch_size=max_batch_size, max_input_len=max_input_len, max_seq_len=max_input_len + max_output_len, max_beam_width=max_beam_width, model_config=model_config, world_config=world_config) def _check_inputs(self, batch_input_ids: List[List[int]], sampling_config: trtllm.SamplingConfig, max_new_tokens): batch_size = len(batch_input_ids) if batch_size > self.max_batch_size: raise RuntimeError( f"Input batch size ({batch_size}) exceeds the engine or specified limit ({self.max_batch_size})" ) input_lengths = [len(x) for x in batch_input_ids] max_length = max(input_lengths) if max_length > self.max_input_len: raise RuntimeError( f"Maximum input length ({max_length}) exceeds the engine or specified limit ({self.max_input_len})" ) if max_length + max_new_tokens > self.max_seq_len: raise RuntimeError( f"Maximum input length ({max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})" ) if sampling_config.beam_width > self.max_beam_width: raise RuntimeError( f"Num beams ({sampling_config.beam_width}) exceeds the engine or specified limit ({self.max_beam_width})" ) @property def dtype(self) -> torch.dtype: bindings_dtype = self.model_config.data_type return _bindings_dtype_to_torch_dtype_dict[bindings_dtype] @property def vocab_size(self) -> int: return self.model_config.vocab_size @property def vocab_size_padded(self) -> int: return self.model_config.vocab_size_padded(self.world_config.size) @property def hidden_size(self) -> int: return self.model_config.hidden_size @property def num_heads(self) -> int: return self.model_config.num_heads @property def num_layers(self) -> int: return self.model_config.num_layers( self.world_config.pipeline_parallelism) @property def max_sequence_length(self) -> int: return self.max_seq_len @property def remove_input_padding(self) -> bool: return self.model_config.use_packed_input @property def max_prompt_embedding_table_size(self) -> int: return self.model_config.max_prompt_embedding_table_size @property def gather_context_logits(self) -> bool: return self.model_config.compute_context_logits @property def gather_generation_logits(self) -> bool: return self.model_config.compute_generation_logits def generate(self, batch_input_ids: List[torch.Tensor], *, encoder_input_ids: List[torch.Tensor] = None, sampling_config: Optional[SamplingConfig] = None, lora_uids: Optional[list] = None, streaming: bool = False, stopping_criteria: Optional[StoppingCriteria] = None, logits_processor: Optional[LogitsProcessor] = None, max_new_tokens: int = 1, end_id: int | None = None, pad_id: int | None = None, bad_words_list: list[list[int]] | None = None, stop_words_list: list[list[int]] | None = None, return_dict: bool = False, output_sequence_lengths: bool = False, output_log_probs: bool = False, output_cum_log_probs: bool = False, prompt_table: Optional[Union[str, torch.Tensor]] = None, prompt_tasks: Optional[str] = None, **kwargs) -> Union[torch.Tensor, dict]: """ Generates sequences of token ids. The generation-controlling parameters are set in the sampling_config; it will be set to a default one if not passed. You can override any sampling_config's attributes by passing corresponding parameters. Args: batch_input_ids (List[torch.Tensor]): A list of input id tensors. Each tensor is of shape (sequence_length, ). sampling_config (SamplingConfig): The sampling configuration to be used as base parametrization for the generation call. The passed **kwargs matching the sampling_config's attributes will override them. If the sampling_config is not provided, a default will be used. prompt_table (str or torch.Tensor): The file path of prompt table (.npy format, exported by nemo_prompt_convert.py) or the prompt table itself. prompt_tasks (str): The prompt tuning task ids for the input batch, in format of comma-separated list (e.g., 0,3,1,0). lora_uids (list): The uids of LoRA weights for the input batch. Use -1 to disable the LoRA module. streaming (bool): Whether or not to use streaming mode for generation. stopping_criteria (StoppingCriteria): Custom stopping criteria. logits_processor (LogitsProcessor): Custom logits processors. kwargs (Dict[str, Any]: Ad hoc parametrization of sampling_config. The passed **kwargs matching the sampling_config's attributes will override them. Returns: torch.Tensor or dict: If return_dict=False, the method returns generated output_ids. If return_dict=True, the method returns a dict of output_ids, sequence_lengths (if sampling_config.output_sequence_lengths=True), context_logits and generation_logits (if self.gather_context_logits=True and self.gather_generation_logits=True, respectively). """ # TODO: Check if these can be supported now and support them if lora_uids is not None: raise RuntimeError("LoRA is not supported in C++ session.") if stopping_criteria is not None: raise RuntimeError( "Stopping criteria is not supported in C++ session.") if logits_processor is not None: raise RuntimeError( "Logits processor is not supported in C++ session.") # If we are in a multi-gpu scenario, only rank 0 continues if not self.session.can_enqueue_requests(): return [] # Convert tensor input to plain lists batch_input_ids_list = [a.tolist() for a in batch_input_ids] encoder_input_ids_list = [a.tolist() for a in encoder_input_ids ] if encoder_input_ids else None if sampling_config is None: # Convert from old API of SamplingConfig # Note: Due to a Python3.10 bug one cannot use inspect on it currently accepted_parameters = [ "num_beams", "top_k", "top_p", "top_p_min", "top_p_reset_ids", "top_p_decay", "random_seed", "temperature", "min_length", "beam_search_diversity_rate", "repetition_penalty", "presence_penalty", "frequency_penalty", "length_penalty", "early_stopping", "no_repeat_ngram_size" ] rename_params = {"num_beams": "beam_width"} sampling_params = { k: v for k, v in kwargs.items() if k in accepted_parameters } for k, v in rename_params.items(): if k in sampling_params: sampling_params[v] = sampling_params.pop(k) if "top_p" in sampling_params and sampling_params["top_p"] == 0.0: sampling_params["top_p"] = None sampling_config = trtllm.SamplingConfig(**sampling_params) else: sampling_config = copy.deepcopy(sampling_config) self._check_inputs( encoder_input_ids_list if encoder_input_ids else batch_input_ids_list, sampling_config, max_new_tokens) output_config = trtllm.OutputConfig( return_context_logits=self.gather_context_logits, return_generation_logits=self.gather_generation_logits, return_log_probs=output_log_probs, ) prompt_tuning_configs = self._prepare_ptuning_executor( batch_input_ids_list, prompt_table, prompt_tasks) stop_words_list = self._prepare_words_list(stop_words_list, len(batch_input_ids_list)) bad_words_list = self._prepare_words_list(bad_words_list, len(batch_input_ids_list)) requests = [ trtllm.Request(input_token_ids=input_ids, encoder_input_token_ids=encoder_input_ids_list[i] if encoder_input_ids is not None else None, max_new_tokens=max_new_tokens, pad_id=pad_id, end_id=end_id, stop_words=stop_words, bad_words=bad_words, sampling_config=sampling_config, streaming=streaming, output_config=output_config, prompt_tuning_config=prompt_tuning_config) for i, (input_ids, stop_words, bad_words, prompt_tuning_config) in enumerate( zip(batch_input_ids_list, stop_words_list, bad_words_list, prompt_tuning_configs)) ] request_ids = self.session.enqueue_requests(requests) if not streaming: return self._initialize_and_fill_output(request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming) else: return self._stream(request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, batch_input_ids_list) def _prepare_words_list(self, words_list: torch.Tensor, batch_size: int): if words_list is None: return [None] * batch_size words_list = words_list.tolist() assert len(words_list) == batch_size for words in words_list: while len(words[0]) > 0 and words[0][-1] == 0: words[0].pop() words[1].pop() return words_list def _prepare_ptuning_executor(self, batch_input_ids_list, prompt_table, prompt_tasks): prompt_tuning_configs = len(batch_input_ids_list) * [None] if prompt_table is not None: prompt_table_data = self._prepare_embedding_table( prompt_table).cuda() if prompt_tasks is not None: task_indices = [int(t) for t in prompt_tasks.split(',')] assert len(task_indices) == len(batch_input_ids_list), \ f"Number of supplied tasks ({len(task_indices)}) must match input batch size ({len(batch_input_ids_list)})" prompt_tuning_configs = [ trtllm.PromptTuningConfig( embedding_table=prompt_table_data[task_indices[i]]) for i in range(len(batch_input_ids_list)) ] else: prompt_tuning_configs = [ trtllm.PromptTuningConfig( embedding_table=prompt_table_data[0]) for _ in range(len(batch_input_ids_list)) ] return prompt_tuning_configs def _initialize_and_fill_output(self, request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming): output_ids = [[] for _ in range(len(request_ids))] for reqid_pos in range(len(request_ids)): output_ids[reqid_pos] = [[] for _ in range(self.max_beam_width)] multi_responses = self.session.await_responses(request_ids) responses = [ response for responses in multi_responses for response in responses ] return self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, request_ids) def _stream(self, request_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, batch_input_ids_list): output_ids = [[] for _ in range(len(request_ids))] for reqid_pos in range(len(request_ids)): output_ids[reqid_pos] = [ copy.deepcopy(batch_input_ids_list[reqid_pos]) for _ in range(self.max_beam_width) ] finished_reqs = 0 while finished_reqs < len(request_ids): responses = self.session.await_responses() for response in responses: if response.result.is_final: finished_reqs += 1 yield self._fill_output(responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, request_ids) def _fill_output(self, responses, output_ids, end_id, return_dict, output_sequence_lengths, output_log_probs, output_cum_log_probs, batch_input_ids, streaming, request_ids): cuda_device = torch.device("cuda") for response in responses: if response.has_error(): raise RuntimeError(response.error_msg) reqid_pos = request_ids.index(response.request_id) for beam, output_tokens in enumerate( response.result.output_token_ids): output_ids[reqid_pos][beam] += output_tokens sequence_lengths = [] for output in output_ids: sequence_lengths.append([len(a) for a in output]) if streaming: output_ids = copy.deepcopy(output_ids) for beam in output_ids: for output_tokens in beam: output_tokens += (self.max_seq_len - len(output_tokens)) * [end_id] output_ids = torch.tensor(output_ids, dtype=torch.int32, device=cuda_device) if return_dict: outputs = {'output_ids': output_ids} if output_sequence_lengths: outputs['sequence_lengths'] = torch.tensor(sequence_lengths, dtype=torch.int32, device=cuda_device) if self.gather_context_logits: outputs['context_logits'] = [ a.result.context_logits.cuda() for a in responses if a.result.context_logits is not None ] # Pad context_logits into a rectangle max_input_length = max(a.shape[0] for a in outputs['context_logits']) for i, a in enumerate(outputs['context_logits']): pad_length = max_input_length - a.shape[0] outputs['context_logits'][i] = torch.nn.functional.pad( a, [0, 0, 0, pad_length]) outputs['context_logits'] = torch.stack( outputs['context_logits']) if self.gather_generation_logits: outputs['generation_logits'] = [ a.result.generation_logits.cuda() for a in responses if a.result.generation_logits is not None ] outputs['generation_logits'] = torch.stack( outputs['generation_logits']) if output_log_probs: outputs['log_probs'] = [ a.result.log_probs for a in responses if a.result.log_probs is not None ] # Pad log_probs into a rectangle max_seq_len = max( len(a) for beam_list in outputs['log_probs'] for a in beam_list) for i, a in enumerate(outputs['log_probs']): for j, b in enumerate(a): pad_length = max_seq_len - len(b) outputs['log_probs'][i][j] = b + [0.0] * pad_length outputs['log_probs'] = torch.tensor(outputs['log_probs'], device=cuda_device) if output_cum_log_probs: outputs['cum_log_probs'] = [ a.result.cum_log_probs for a in responses if a.result.cum_log_probs is not None ] outputs['cum_log_probs'] = torch.tensor( outputs['cum_log_probs'], device=cuda_device) input_lengths = torch.tensor([x.size(0) for x in batch_input_ids], dtype=torch.int32, device=cuda_device) outputs = self._prepare_outputs(outputs, input_lengths) else: outputs = output_ids return outputs