# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
import random
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import torch
from janus import LifoQueue, Queue
from transformers import AutoTokenizer

import tensorrt_llm.bindings as tllm
from tensorrt_llm.hlapi.llm import LLM, ModelConfig
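

# AsyncLLMEngine wraps the TensorRT-LLM in-flight batching GptManager behind an
# asynchronous, token-streaming generate() API. GptManager pulls pending work and
# pushes results through the callbacks registered in __init__ (fetch_requests,
# handle_response, get_stop_set, handle_stats), while the janus queues bridge
# those synchronous callbacks and the asyncio consumers awaiting per-request
# results.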
class AsyncLLMEngine:
    # Request id reserved for the terminate signal passed to GptManager;
    # gen_id() always returns ids strictly greater than this.
    TERMINATE_REQUEST_ID = 0

    def __init__(self,
                 engine_dir: Path,
                 tokenizer: str | Path,
                 max_beam_width: int = 1,
                 max_num_sequences: int = 10) -> None:
        # Pending requests waiting to be picked up by fetch_requests, and
        # per-request result queues filled by handle_response.
        self.requests: list[tllm.InferenceRequest] = []
        self.results: dict[int, Queue] = {}
        self.stop_set: set[int] = set()
        self.stats: LifoQueue = LifoQueue()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer,
                                                       legacy=False,
                                                       padding_side='left',
                                                       truncation_side='left',
                                                       trust_remote_code=True,
                                                       use_fast=True)
        opt_params = tllm.TrtGptModelOptionalParams(
            max_num_sequences=max_num_sequences)
        self.engine = tllm.GptManager(
            engine_dir, tllm.TrtGptModelType.InflightBatching, max_beam_width,
            tllm.SchedulerPolicy.MAX_UTILIZATION, self.fetch_requests,
            self.handle_response, self.get_stop_set, self.handle_stats,
            opt_params, AsyncLLMEngine.TERMINATE_REQUEST_ID)

    @staticmethod
    def from_hf_dir(model_dir: str | Path):
        # Build a TensorRT-LLM engine from a Hugging Face checkpoint directory
        # and wrap it in an AsyncLLMEngine that loads the tokenizer from the
        # same directory.
        config = ModelConfig(model_dir=str(model_dir))
        config.build_config.plugin_config.set_gemm_plugin()
        config.build_config.plugin_config.set_context_fmha()
        config.build_config.plugin_config.set_gpt_attention_plugin()
        config.build_config.plugin_config.enable_paged_kv_cache()
        config.build_config.plugin_config.enable_remove_input_padding()

        engine_dir = TemporaryDirectory()
        LLM(config).save(engine_dir.name)
        engine = AsyncLLMEngine(Path(engine_dir.name), model_dir)
        # Keep a reference to the tmp dir on the object so it is cleaned up
        # once the engine disappears.
        setattr(engine, "_tmp_dir", engine_dir)

        return engine

    @staticmethod
    def gen_id() -> int:
        # The underlying request id type is uint64; skip the reserved
        # terminate id.
        uint64_max = 2**64 - 1
        return random.randint(AsyncLLMEngine.TERMINATE_REQUEST_ID + 1,
                              uint64_max)

    @staticmethod
    def create_inference_request(
            req_id: int, parameters: dict[str, Any]) -> tllm.InferenceRequest:

        def set_property(name: str, dtype: torch.dtype = torch.int32):
            # Copy an optional scalar parameter into a one-element tensor on
            # the request, if the caller provided it.
            if name in parameters:
                setattr(request, name,
                        torch.tensor([parameters[name]], dtype=dtype))

        request = tllm.InferenceRequest(req_id)
        request.input_ids = parameters["input_ids"]
        set_property("max_new_tokens")
        set_property("end_id")
        set_property("pad_id")
        set_property("min_length")
        set_property("temperature", torch.float32)
        set_property("runtime_top_k", torch.float32)
        set_property("runtime_top_p", torch.float32)
        set_property("random_seed", torch.int64)
        if "streaming" in parameters:
            request.is_streaming = parameters["streaming"]

        return request

    def add_request(self, request_dict: dict[str,
                                             Any]) -> tllm.InferenceRequest:
        ids = self.tokenizer(request_dict.pop("prompt"),
                             return_tensors="pt",
                             return_attention_mask=False)
        request_dict["input_ids"] = ids["input_ids"].to(torch.int32)
        request_dict["end_id"] = self.tokenizer.eos_token_id
        if getattr(self.tokenizer, "pad_token_id") is not None:
            request_dict["pad_id"] = self.tokenizer.pad_token_id
        else:
            request_dict["pad_id"] = request_dict["end_id"]

        request = AsyncLLMEngine.create_inference_request(
            AsyncLLMEngine.gen_id(), request_dict)

        self.results[request.request_id] = Queue()
        self.requests.append(request)

        return request

    async def get_response(self,
                           request_id: int) -> tuple[torch.Tensor, bool]:
        outputs, finished = None, False
        while outputs is None:
            outputs, finished = await self.results[request_id].async_q.get()

        # Only beam 0 of the first sequence is returned; sequence_length marks
        # how many of the output ids are valid.
        last_idx = outputs["sequence_length"][0, 0].item()
        output = outputs["output_ids"][0, 0, :last_idx]

        if finished:
            self.results.pop(request_id)

        return output, finished

    async def generate(self,
                       prompt: str,
                       max_new_tokens: int,
                       streaming: bool = True):
        tllm_request = self.add_request({
            "prompt": prompt,
            "max_new_tokens": [max_new_tokens],
            "streaming": streaming
        })
        request_id = tllm_request.request_id
        current_tokens = tllm_request.input_ids[0].numpy().tolist()
        current_str = self.tokenizer.decode(current_tokens)

        finished = False
        while not finished:
            output, finished = await self.get_response(request_id)

            # Decode the full token sequence and yield only the newly
            # generated suffix.
            current_tokens += output.numpy().tolist()
            new_str = self.tokenizer.decode(current_tokens)
            diff_str = new_str[len(current_str):]
            current_str = new_str

            yield diff_str

    # Callbacks for BatchManager
    def fetch_requests(self, max_num_sequences) -> list[tllm.InferenceRequest]:
        fetched = []
        for _ in range(max_num_sequences):
            if len(self.requests) == 0:
                break
            fetched.append(self.requests.pop())
        return fetched

    def handle_response(self, req_id: int, tensors: list[tllm.NamedTensor],
                        is_ok: bool, err_msg: str) -> None:
        # Forward either the named output tensors or the error message to the
        # queue of the request that produced them.
        self.results[req_id].sync_q.put(
            [{t.name: t.tensor
              for t in tensors}, is_ok] if not err_msg else err_msg)

    def get_stop_set(self) -> set[int]:
        return self.stop_set

    def handle_stats(self, stats: str):
        # Make room if the stats queue is full before recording the latest
        # snapshot.
        while self.stats.sync_q.full():
            self.stats.sync_q.get()

        self.stats.sync_q.put(stats)
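

# Example usage (an illustrative sketch, not part of the upstream module):
# build an engine from a local Hugging Face checkpoint directory given on the
# command line and stream the generated text to stdout. The prompt and token
# budget below are placeholders.
if __name__ == "__main__":
    import asyncio
    import sys

    async def _demo(model_dir: str) -> None:
        engine = AsyncLLMEngine.from_hf_dir(model_dir)
        async for delta in engine.generate("Hello, my name is",
                                           max_new_tokens=32):
            print(delta, end="", flush=True)
        print()

    # Run with: python <this file> <path to Hugging Face model directory>
    asyncio.run(_demo(sys.argv[1]))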