# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast

import numpy as np
import torch
from ordered_set import OrderedSet

from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import ModelRunnerCpp


class NgramPool:
    """Pool of n-grams used by the NGram speculative decoding workflow."""

    def __init__(
        self,
        input_batch_size: int,
        max_draft_len: int,
        max_matching_ngram_size: int,
        end_id: int,
        max_seq_len: list[int],
        is_keep_all: bool = True,
        is_use_oldest: bool = True,
    ):
        self.input_batch_size = input_batch_size
        self.max_draft_len = max_draft_len
        self.max_matching_ngram_size = max_matching_ngram_size
        self.end_id = end_id
        self.max_seq_len = max_seq_len
        self.is_keep_all = is_keep_all
        self.is_use_oldest = is_use_oldest
        # One key->values map and one scan start index per request slot
        self.pool = [{} for _ in range(input_batch_size)]
        self.start_index = [0 for _ in range(input_batch_size)]
        assert self.max_draft_len > 0, \
            f"max_draft_len must be greater than 0, but got {self.max_draft_len}"
        assert self.max_matching_ngram_size > 0, \
            f"max_matching_ngram_size must be greater than 0, but got {self.max_matching_ngram_size}"

    def print_pool(self):
        """
        For debugging
        """
        logger.info(f"Batch size = {self.input_batch_size}")
        for i, slot_pool in enumerate(self.pool):
            logger.info(f"Slot {i}, size = {len(slot_pool)}")
            for key, values in slot_pool.items():
                logger.info(f"    {key}->{values}")

    def get_draft_tokens(self, prefix: list[torch.Tensor],
                         batch_slot: list[int]):
        """
        Get draft tokens for a batch of requests,
        modified from `transformers/generation/candidate_generator.py`
        """
        batch_size = len(prefix)
        prefix_len = [len(prefix[bi]) for bi in range(batch_size)]
        draft_tokens = []  # Draft logits are not produced in this approach yet
        for bi in range(batch_size):
            gbi = batch_slot[bi]  # Global index in the input batch
            chosen_ids = [self.end_id]
            # Skip the search if the prefix has already reached length `max_seq_len - 1`
            if prefix_len[bi] >= self.max_seq_len[gbi] - 1:
                draft_tokens.append(chosen_ids)
                continue
            # Update pool
            sequence = prefix[bi][self.start_index[gbi]:].tolist()
            for size in range(
                    min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
                    -1):
                # Find each possible key-value combination, and use tuple for hash
                for l in range(len(sequence) - size):
                    r = min(l + size + self.max_draft_len, len(sequence))
                    key = tuple(sequence[l:l + size])
                    value = tuple(sequence[l + size:r])
                    if key not in self.pool[gbi] or not self.is_keep_all or \
                        len(self.pool[gbi][key][0]) < self.max_draft_len:
                        # Overwrite the value if any of the following holds:
                        # 1. the key does not exist yet
                        # 2. we only keep the newest value for each key (MRU)
                        # 3. the value saved before is shorter than `max_draft_len`
                        self.pool[gbi][key] = OrderedSet((value, ))
                    elif value not in self.pool[gbi][key]:
                        # Record a new value for an existing key if it is not stored yet
                        self.pool[gbi][key].add(value)
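
            # Illustrative example of the pool contents (the token ids below are
            # made up for this comment only): with `max_matching_ngram_size=2`,
            # `max_draft_len=2` and `sequence = [5, 7, 9, 5, 7]`, the loop above
            # stores, among other entries, (5, 7)->(9, 5) and (7, 9)->(5, 7).
            # The matching step below then finds the tail (5, 7) in the pool and
            # proposes (9, 5) as the draft tokens.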

            # Find match
            for size in range(
                    min(self.max_matching_ngram_size, prefix_len[bi] - 1), 0,
                    -1):
                pattern = tuple(prefix[bi][-size:].tolist())
                if pattern not in self.pool[gbi]:
                    continue
                if self.is_use_oldest:
                    # Always choose the oldest match, aligned with HF
                    chosen_ids = self.pool[gbi][pattern][0]
                else:
                    # Always choose the newest match
                    chosen_ids = self.pool[gbi][pattern][-1]
                break
            draft_tokens.append(chosen_ids)
            # The next call only needs to rescan the tail that can still form new key-value pairs
            self.start_index[gbi] = max(
                0, prefix_len[bi] -
                (self.max_draft_len + self.max_matching_ngram_size - 1))

        return draft_tokens, None
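

# Example usage of `NgramPool` (an illustrative sketch; the token ids, sizes and
# `end_id` below are made up for this comment and are not taken from the real
# workflow):
#
#     pool = NgramPool(input_batch_size=1, max_draft_len=2,
#                      max_matching_ngram_size=2, end_id=2, max_seq_len=[32])
#     prefix = [torch.IntTensor([5, 7, 9, 5, 7])]
#     draft_tokens, _ = pool.get_draft_tokens(prefix, batch_slot=[0])
#     # draft_tokens[0] == (9, 5), the continuation seen after the earlier "5 7"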


def run_dtm_ngram(batch_input_ids,
                  args,
                  runtime_rank,
                  end_id,
                  pad_id,
                  stop_words_list,
                  bad_words_list,
                  vocab_size,
                  *,
                  target_runner=None):
    # `dtm` for Draft-Target-Model, `ngram` for NGram
    is_dtm = (args.draft_target_model_config is not None)
    is_ngram = (args.ngram_config is not None)
    assert is_dtm ^ is_ngram, \
        "`--draft_target_model_config` and `--ngram_config` cannot be specified at the same time."

    if is_dtm:
        assert args.draft_engine_dir is not None, \
            "`--draft_engine_dir` must be specified in Draft-Target-Model."
        # The config is a Python literal such as "[4,[0],[1],False]", i.e.
        # (draft_len, draft_device_list, target_device_list, use_logits)
        draft_len, draft_device_list, target_device_list, use_logits = ast.literal_eval(
            args.draft_target_model_config)
        logger.info(f"Using Draft-Target-Model speculative decoding")
        logger.info(f"draft_len: {draft_len}")
        logger.info(f"Device(s) for draft model: {draft_device_list}")
        logger.info(f"Device(s) for target model: {target_device_list}")
        logger.info(f"Use logits to accept tokens: {use_logits}")
    if is_ngram:
        logger.info(f"Using NGram speculative decoding V1 workflow")
        # The config is a Python literal such as "[10,2,[0]]", i.e.
        # (max_draft_len, max_matching_ngram_size, target_device_list)
        max_draft_len, max_matching_ngram_size, target_device_list = ast.literal_eval(
            args.ngram_config)
        logger.info(f"max_draft_len: {max_draft_len}")
        logger.info(f"max_matching_ngram_size: {max_matching_ngram_size}")
        logger.info(f"Device(s) for the model: {target_device_list}")
        use_logits = False  # Logits are not used in this approach yet

    # Variables kept constant during decoding
    input_batch_size = len(batch_input_ids)  # Noted as `BS`
    beam_width = args.num_beams  # Noted as `BW`
    is_compute_acceptance_ratio = logger.level == 'verbose'  # Only for verbose
    input_len = [len(p) for p in batch_input_ids]
    max_seq_len = [i + args.max_output_len for i in input_len]

    # Variables changing during decoding
    n_iteration = 0
    prefix = batch_input_ids  # Input for each iteration
    batch_slot = list(range(input_batch_size))  # Index of requests
    if is_compute_acceptance_ratio:
        n_draft_token = [0 for _ in range(input_batch_size)]
        n_accept_token = [0 for _ in range(input_batch_size)]
    if is_ngram:
        ngram_pool = NgramPool(input_batch_size, max_draft_len,
                               max_matching_ngram_size, end_id, max_seq_len)

    # Repack the output like the output of function `generate`
    outputs = {}
    outputs["output_ids"] = torch.full(
        [input_batch_size, beam_width, max(max_seq_len)],
        end_id,
        dtype=torch.int32)
    for bi in range(input_batch_size):
        outputs["output_ids"][bi, :, :input_len[bi]] = batch_input_ids[bi]
    outputs["sequence_lengths"] = torch.full([input_batch_size, beam_width],
                                             0,
                                             dtype=torch.int32)
    outputs["context_logits"] = None
    outputs["generation_logits"] = torch.full(
        [input_batch_size, beam_width, max(max_seq_len), vocab_size],
        0,
        dtype=torch.float16)
    outputs['cum_log_probs'] = None
    outputs['log_probs'] = None

    # Model runner
    common_runner_kwargs = dict(
        lora_dir=args.lora_dir,
        rank=runtime_rank,
        debug_mode=args.debug_mode,
        lora_ckpt_source=args.lora_ckpt_source,
        gpu_weights_percent=args.gpu_weights_percent,
        max_output_len=args.max_output_len,
        is_enc_dec=False,
        max_batch_size=input_batch_size,
        max_input_len=max(input_len) + args.max_output_len,
        max_beam_width=beam_width,
        max_attention_window_size=args.max_attention_window_size,
        sink_token_length=args.sink_token_length,
        max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
        kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
        kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
        cross_kv_cache_fraction=None,
        enable_chunked_context=args.enable_chunked_context,
        multi_block_mode=args.multi_block_mode,
        cuda_graph_mode=args.cuda_graph_mode,
        enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc,
        is_orchestrator_mode=True,
    )

    if is_dtm:
        draft_runner_kwargs = common_runner_kwargs.copy()
        draft_runner_kwargs.update(engine_dir=args.draft_engine_dir,
                                   device_ids=draft_device_list)
        draft_runner = ModelRunnerCpp.from_dir(**draft_runner_kwargs)
    if target_runner is None:
        # Skip this constructor if the runner has been prepared before
        target_runner_kwargs = common_runner_kwargs.copy()
        target_runner_kwargs.update(engine_dir=args.engine_dir,
                                    device_ids=target_device_list)
        target_runner = ModelRunnerCpp.from_dir(**target_runner_kwargs)

    if is_dtm and use_logits:
        assert draft_runner.gather_generation_logits and target_runner.gather_generation_logits, \
            "`--gather_generation_logits` must be specified while building draft/target models for using logits to accept"

    common_generation_kwargs = dict(
        max_attention_window_size=args.max_attention_window_size,
        sink_token_length=args.sink_token_length,
        end_id=end_id,
        pad_id=pad_id,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_beams=beam_width,
        num_return_sequences=args.num_return_sequences,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        beam_width_array=None,
        repetition_penalty=args.repetition_penalty,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        min_p=args.min_p,
        stop_words_list=stop_words_list,
        bad_words_list=bad_words_list,
        random_seed=args.random_seed,
        lora_uids=args.lora_task_uids,
        prompt_table=args.prompt_table_path,
        prompt_tasks=args.prompt_tasks,
        streaming=False,
        output_sequence_lengths=True,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        return_dict=True,
        return_all_generated_tokens=args.return_all_generated_tokens,
    )
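
    # Each iteration of the loop below is one round of speculative decoding:
    # 1. propose draft tokens (from the draft engine in Draft-Target-Model, or
    #    from the n-gram pool in NGram),
    # 2. let the target engine verify them through `draft_tokens_list`; it
    #    returns the accepted tokens plus normally one extra token,
    # 3. write the accepted tokens into `outputs` and drop finished requests
    #    before the next round.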

    while True:
        n_iteration += 1
        # Dynamic batch size, decreases when some requests finish
        batch_size = len(prefix)
        prefix_len = [len(prefix[i]) for i in range(batch_size)]

        # Get draft tokens
        # `d_*` means variables from the draft model
        # `d_seq_len` includes the input part, but `d_len` doesn't
        if is_dtm:
            draft_generation_kwargs = common_generation_kwargs.copy()
            draft_generation_kwargs.update(
                batch_input_ids=prefix,
                max_new_tokens=draft_len,
                streaming=False,
                output_sequence_lengths=True,
                return_dict=True,
            )
            draft = draft_runner.generate(**draft_generation_kwargs)
            torch.cuda.synchronize()
            # draft["output_ids"].shape -> [BS, BW, maxSL]
            # draft["sequence_lengths"].shape -> [BS, BW]
            # draft["generation_logits"].shape -> [BS, BW, draft_len, vocab_size]
            d_ids = [[end_id]] * batch_size
            d_logits = [None] * batch_size if use_logits else None
            d_seq_len = draft["sequence_lengths"][:, 0].tolist()
            d_len = [
                d_seq_len[bi] - prefix_len[bi] for bi in range(batch_size)
            ]
            for bi in range(batch_size):
                l, r = prefix_len[bi], d_seq_len[bi]
                if l >= r:  # No useful draft tokens
                    continue
                d_ids[bi] = draft["output_ids"][bi, 0, l:r].tolist()
                if use_logits:
                    d_logits[bi] = draft["generation_logits"][bi, 0,
                                                              -d_len[bi]:, :]
        if is_ngram:
            d_ids, d_logits = ngram_pool.get_draft_tokens(prefix, batch_slot)
            d_len = [len(i) for i in d_ids]

        # Run the target model
        # `t_*` means variables from the target model
        # `t_seq_len` and `t_seq_ids` include the input part, but `t_len` and `t_ids` don't
        target_generation_kwargs = common_generation_kwargs.copy()
        target_generation_kwargs.update(batch_input_ids=prefix,
                                        draft_tokens_list=d_ids,
                                        draft_logits_list=d_logits)
        if is_dtm:
            max_new_tokens = draft_len + 1
        if is_ngram:
            max_new_tokens = max_draft_len + 1
        target_generation_kwargs.update(max_new_tokens=max_new_tokens)
        target = target_runner.generate(**target_generation_kwargs)
        torch.cuda.synchronize()

        t_ids = [None] * batch_size
        t_seq_ids = [None] * batch_size
        t_seq_len = target["sequence_lengths"][:, 0].tolist()
        t_len = [t_seq_len[bi] - prefix_len[bi] for bi in range(batch_size)]

        # Update output and tokens for the next iteration
        for bi in range(batch_size):
            gbi = batch_slot[bi]  # Global index in the input batch
            l = prefix_len[bi]
            r = min(t_seq_len[bi], max_seq_len[gbi])
            t_ids[bi] = target["output_ids"][bi, 0, l:r].tolist()
            t_seq_ids[bi] = target["output_ids"][bi, 0, :r]
            outputs["output_ids"][gbi, 0, l:r] = torch.IntTensor(t_ids[bi])
            outputs["sequence_lengths"][gbi, 0] = r
            if use_logits:
                outputs["generation_logits"][gbi, 0, (l - input_len[bi]):(r - input_len[bi])] = \
                    target["generation_logits"][bi][0, :(r - l)].detach().cpu()
            if is_compute_acceptance_ratio:
                # Count the longest prefix of draft tokens confirmed by the target model
                n_draft_token[gbi] += d_len[bi]
                length = min(d_len[bi], t_len[bi],
                             max_seq_len[gbi] - prefix_len[bi])
                res = [d_ids[bi][i] == t_ids[bi][i] for i in range(length)]
                n_accept_token[gbi] += \
                    ((~torch.BoolTensor(res)).cumsum(axis=-1) < 1).sum()

        # Yield output if using streaming
        if args.streaming and not n_iteration % args.streaming_interval:
            yield outputs

        # Evaluate stop criteria and prepare inputs for the next iteration
        prefix_next = []
        batch_slot_next = []
        for bi in range(batch_size):
            gbi = batch_slot[bi]  # Global index in the input batch
            # Stop due to output length
            if len(t_seq_ids[bi]) >= max_seq_len[gbi]:
                continue  # No need to update the stopped requests
            # Stop due to the same output. Normally the target should return 1 more token.
            # if (d_ids is not None and np.array_equal(d_ids[bi], t_ids[bi])):
            #     continue
            # Stop due to no change (hit early stopping)
            if np.array_equal(t_seq_ids[bi].cpu().numpy(),
                              prefix[bi].cpu().numpy()):
                continue
            # Stop due to the end token
            if end_id in t_seq_ids[bi][prefix_len[bi]:]:
                continue
            # TODO: Check bad words and stop words criteria
            prefix_next.append(t_seq_ids[bi])
            batch_slot_next.append(gbi)
        prefix = prefix_next
        batch_slot = batch_slot_next

        if len(prefix) == 0:  # Leave the loop if no requests remain
            break

    if is_compute_acceptance_ratio:
        logger.debug(f"Count of iteration(s): {n_iteration}")
        logger.debug(f"Acceptance ratio:")
        for i, (a, d) in enumerate(zip(n_accept_token, n_draft_token)):
            logger.debug(f"Request {i}: {a / d * 100 :6.2f}%")

    # Yield the final outputs; in non-streaming mode also return the runner so it can be reused
    if args.streaming:
        yield outputs
    else:
        yield outputs, target_runner
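

# Example driver (an illustrative sketch, not executed here): `run_dtm_ngram` is
# a generator, so the caller iterates over it; `args` is assumed to be the
# parsed command-line namespace of the surrounding example script.
#
#     generator = run_dtm_ngram(batch_input_ids, args, runtime_rank, end_id,
#                               pad_id, stop_words_list, bad_words_list,
#                               vocab_size)
#     if args.streaming:
#         for outputs in generator:
#             ...  # consume intermediate (and final) outputs
#     else:
#         outputs, target_runner = next(generator)
#         # `target_runner` can be passed back via `target_runner=...` to avoid
#         # rebuilding the runner for the next batch.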