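"""Example: JSON-schema-constrained generation with the TensorRT-LLM executor
bindings, using lm-format-enforcer to post-process logits.

Illustrative invocation (script name and paths are placeholders):
    python example.py -t <tokenizer_dir> -e <engine_dir>
"""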
import argparse
import datetime
import typing as _tp

import torch as _tor
from lmformatenforcer import (JsonSchemaParser, TokenEnforcer,
                              TokenEnforcerTokenizerData)
from pydantic import BaseModel
from transformers import AutoTokenizer

import tensorrt_llm.bindings.executor as trtllm


def _build_regular_tokens_list(
        tokenizer) -> _tp.List[_tp.Tuple[int, str, bool]]:
    token_0 = [tokenizer.encode("0")[-1]]
    regular_tokens = []
    vocab_size = tokenizer.vocab_size
    for token_idx in range(vocab_size):
        if token_idx in tokenizer.all_special_ids:
            continue
        # We prepend token 0 and skip the first character of the result to
        # get a leading space if the token is a word-start token.
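        # E.g. for a SentencePiece-style tokenizer (behavior varies by
        # tokenizer), decoding [tok("0"), tok("▁world")] gives "0 world";
        # dropping the "0" leaves " world", and the leading space marks it
        # as a word-start token.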
        tensor_after_0 = _tor.tensor(token_0 + [token_idx], dtype=_tor.long)
        decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]
        decoded_regular = tokenizer.decode(token_0)
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
        regular_tokens.append(
            (token_idx, decoded_after_0, is_word_start_token))
    return regular_tokens


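# lm-format-enforcer drives the constrained decoding: given the token ids
# decoded so far, TokenEnforcer.get_allowed_tokens() returns the token ids
# that keep the output parseable by the character-level parser (here, a
# JSON-schema parser).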
def build_token_enforcer(tokenizer, character_level_parser):
    """
    Build a TokenEnforcer that constrains generation to the given
    character-level parser.
    """
    regular_tokens = _build_regular_tokens_list(tokenizer)

    def _decode(tokens: _tp.List[int]) -> str:
        tensor = _tor.tensor(tokens, dtype=_tor.long)
        return tokenizer.decode(tensor)

    tokenizer_data = TokenEnforcerTokenizerData(regular_tokens, _decode,
                                                tokenizer.eos_token_id)
    return TokenEnforcer(tokenizer_data, character_level_parser)


# Prepare and enqueue the requests
def enqueue_requests(args: argparse.Namespace,
                     executor: trtllm.Executor) -> list[int]:
    sampling_config = trtllm.SamplingConfig(args.beam_width)

    request_ids = []
    for iter_id in range(args.batch_size):
        # Create the request.
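        # client_id alternates 0/1 across the batch so the logits
        # post-processor can apply different per-client behavior
        # (see get_allowed_tokens below).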
        request = trtllm.Request(input_token_ids=prompt,
                                 max_tokens=25,
                                 end_id=tokenizer.eos_token_id,
                                 sampling_config=sampling_config,
                                 client_id=iter_id % 2)
        request.logits_post_processor_name = (
            request.BATCHED_POST_PROCESSOR_NAME
            if args.lpp_batched else "my_logits_pp")

        # Enqueue the request.
        req_id = executor.enqueue_request(request)
        request_ids.append(req_id)

    return request_ids


# Wait for responses and store output tokens
def wait_for_responses(
        args: argparse.Namespace, request_ids: list[int],
        executor: trtllm.Executor) -> dict[int, dict[int, list[int]]]:
    output_tokens = {
        req_id: {beam: []
                 for beam in range(args.beam_width)}
        for req_id in request_ids
    }
    num_finished = 0
    num_iters = 0
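    # await_responses() blocks until responses arrive or the timeout elapses;
    # a request counts as finished once a response with result.is_final is
    # received.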
    while num_finished < len(request_ids) and num_iters < args.timeout_ms:
        responses = executor.await_responses(
            datetime.timedelta(milliseconds=args.timeout_ms))
        num_iters += 1
        for response in responses:
            req_id = response.request_id
            if not response.has_error():
                result = response.result
                num_finished += 1 if result.is_final else 0
                for beam, out_tokens in enumerate(result.output_token_ids):
                    output_tokens[req_id][beam].extend(out_tokens)
            else:
                raise RuntimeError(
                    f"{req_id} encountered error: {response.error_msg}")

    return output_tokens


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Executor Bindings Example")
    parser.add_argument("--tokenizer_path",
                        "-t",
                        type=str,
                        required=True,
                        help="Directory containing model tokenizer")
    parser.add_argument("--engine_path",
                        "-e",
                        type=str,
                        required=True,
                        help="Directory containing model engine")
    parser.add_argument("--beam_width",
                        type=int,
                        required=False,
                        default=1,
                        help="The beam width")
    parser.add_argument("--batch_size",
                        type=int,
                        required=False,
                        default=1,
                        help="The batch size")
    parser.add_argument(
        "--timeout_ms",
        type=int,
        required=False,
        default=10000,
        help="The maximum time to wait for all responses, in milliseconds")
    parser.add_argument("--lpp_batched",
                        action="store_true",
                        default=False,
                        help="Enable batched logits post processor")

    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    class AnswerFormat(BaseModel):
        last_name: str
        year_of_birth: int

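    # With the schema enforced, a response might look like (illustrative
    # output, not a guaranteed result):
    #   {"last_name": "Jordan", "year_of_birth": 1963}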
    schema_parser = JsonSchemaParser(AnswerFormat.model_json_schema())
    token_enforcer = build_token_enforcer(tokenizer, schema_parser)

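    # Token id 42 below is an arbitrary demonstration value: requests with
    # client_id 0 (or no client_id) are forced to emit it, while all other
    # clients get schema-enforced decoding.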
    def get_allowed_tokens(ids, client_id):
        if client_id is None or client_id == 0:
            return [42]

        def _trim(ids):
            return [x for x in ids if x != tokenizer.eos_token_id]

        return token_enforcer.get_allowed_tokens(_trim(ids[0]))

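    # Build an additive mask on the CPU (-inf everywhere except the allowed
    # token ids), then apply it to the logits on the model's CUDA stream.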
    def logits_post_processor(req_id: int, logits: _tor.Tensor,
                              ids: _tp.List[_tp.List[int]], stream_ptr: int,
                              client_id: _tp.Optional[int]):
        mask = _tor.full_like(logits, fill_value=float("-inf"), device="cpu")
        allowed = get_allowed_tokens(ids, client_id)
        mask[:, :, allowed] = 0

        with _tor.cuda.stream(_tor.cuda.ExternalStream(stream_ptr)):
            mask = mask.to(logits.device, non_blocking=True)
            logits += mask

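    # Batched variant: a single Python callback handles every active request
    # in the batch instead of being invoked once per request.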
    def logits_post_processor_batched(
            req_ids_batch: _tp.List[int],
            logits_batch: _tp.List[_tor.Tensor],
            ids_batch: _tp.List[_tp.List[_tp.List[int]]], stream_ptr: int,
            client_ids_batch: _tp.List[_tp.Optional[int]]):
        masks = []
        for req_id, logits, ids, client_id in zip(req_ids_batch, logits_batch,
                                                  ids_batch, client_ids_batch):
            del req_id
            mask = _tor.full_like(logits,
                                  fill_value=float("-inf"),
                                  device="cpu")
            allowed = get_allowed_tokens(ids, client_id)
            mask[:, :, allowed] = 0
            masks.append(mask)

        with _tor.cuda.stream(_tor.cuda.ExternalStream(stream_ptr)):
            for logits, mask in zip(logits_batch, masks):
                logits += mask.to(logits.device, non_blocking=True)

    # Create the executor.
    executor_config = trtllm.ExecutorConfig(args.beam_width)
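    # Requests pick a processor by name; the batched processor is addressed
    # via the reserved Request.BATCHED_POST_PROCESSOR_NAME.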
    logits_proc_config = trtllm.LogitsPostProcessorConfig()
    if not args.lpp_batched:
        logits_proc_config.processor_map = {
            "my_logits_pp": logits_post_processor
        }
    else:
        logits_proc_config.processor_batched = logits_post_processor_batched
    executor_config.logits_post_processor_config = logits_proc_config
    executor = trtllm.Executor(args.engine_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)

    input_text = ("Please give me information about Michael Jordan. "
                  "You MUST answer using the following json schema: ")
    prompt = tokenizer.encode(input_text)
    print(f"Input text: {input_text}\n")

    if executor.can_enqueue_requests():
        request_ids = enqueue_requests(args, executor)
        output_tokens = wait_for_responses(args, request_ids, executor)

        # Print output
        for req_id in request_ids:
            for beam_id in range(args.beam_width):
                result = tokenizer.decode(
                    output_tokens[req_id][beam_id][len(prompt):])
                generated_tokens = (len(output_tokens[req_id][beam_id]) -
                                    len(prompt))
                print(f"Request {req_id} Beam {beam_id} "
                      f"({generated_tokens} tokens): {result}")