# examples/bindings/executor/example_logits_processor.py

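"""Constrained JSON decoding with an executor logits post-processor.

This example wires lm-format-enforcer into the TensorRT-LLM executor bindings
so that generated text follows the AnswerFormat JSON schema defined below
(last_name, year_of_birth). Example invocation (the tokenizer and engine
directories are placeholders for your own paths):

    python3 example_logits_processor.py -t <tokenizer_dir> -e <engine_dir> \
        --batch_size 2 [--lpp_batched]
"""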
import argparse
import datetime
import typing as _tp

import torch as _tor
from lmformatenforcer import (JsonSchemaParser, TokenEnforcer,
                              TokenEnforcerTokenizerData)
from pydantic import BaseModel
from transformers import AutoTokenizer

import tensorrt_llm.bindings.executor as trtllm
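

# Build (token_id, decoded_text, is_word_start) tuples for every non-special
# token in the vocabulary; lm-format-enforcer uses this table to map token ids
# back to characters.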
def _build_regular_tokens_list(
        tokenizer) -> _tp.List[_tp.Tuple[int, str, bool]]:
    token_0 = [tokenizer.encode("0")[-1]]
    regular_tokens = []
    vocab_size = tokenizer.vocab_size
    for token_idx in range(vocab_size):
        if token_idx in tokenizer.all_special_ids:
            continue
        # We prepend token 0 and skip the first letter of the result to get a
        # space if the token is a word-start token.
        tensor_after_0 = _tor.tensor(token_0 + [token_idx], dtype=_tor.long)
        decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]
        decoded_regular = tokenizer.decode(token_0)
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
        regular_tokens.append(
            (token_idx, decoded_after_0, is_word_start_token))
    return regular_tokens


def build_token_enforcer(tokenizer, character_level_parser):
    """
    Build a TokenEnforcer that tracks the decoded prefix and reports which
    tokens are allowed next under the given character-level parser.
    """
    regular_tokens = _build_regular_tokens_list(tokenizer)

    def _decode(tokens: _tp.List[int]) -> str:
        tensor = _tor.tensor(tokens, dtype=_tor.long)
        return tokenizer.decode(tensor)

    tokenizer_data = TokenEnforcerTokenizerData(regular_tokens, _decode,
                                                tokenizer.eos_token_id)
    return TokenEnforcer(tokenizer_data, character_level_parser)


# Prepare and enqueue the requests
def enqueue_requests(args: argparse.Namespace,
                     executor: trtllm.Executor) -> list[int]:
    sampling_config = trtllm.SamplingConfig(args.beam_width)
    request_ids = []
    for iter_id in range(args.batch_size):
        # Create the request. `prompt` and `tokenizer` are defined in the
        # __main__ block below.
        request = trtllm.Request(input_token_ids=prompt,
                                 max_tokens=25,
                                 end_id=tokenizer.eos_token_id,
                                 sampling_config=sampling_config,
                                 client_id=iter_id % 2)
        # Route the request to the named per-request post-processor or to the
        # batched post-processor.
        request.logits_post_processor_name = (
            request.BATCHED_POST_PROCESSOR_NAME
            if args.lpp_batched else "my_logits_pp")

        # Enqueue the request.
        req_id = executor.enqueue_request(request)
        request_ids.append(req_id)

    return request_ids


# Wait for responses and store output tokens
def wait_for_responses(
        args: argparse.Namespace, request_ids: list[int],
        executor: trtllm.Executor) -> dict[int, dict[int, list[int]]]:
    output_tokens = {
        req_id: {beam: []
                 for beam in range(args.beam_width)}
        for req_id in request_ids
    }
    num_finished = 0
    iter = 0
    while num_finished < len(request_ids) and iter < args.timeout_ms:
        # Bound the number of wait iterations so the loop cannot spin forever.
        iter += 1
        responses = executor.await_responses(
            datetime.timedelta(milliseconds=args.timeout_ms))
        for response in responses:
            req_id = response.request_id
            if not response.has_error():
                result = response.result
                num_finished += 1 if result.is_final else 0
                for beam, outTokens in enumerate(result.output_token_ids):
                    output_tokens[req_id][beam].extend(outTokens)
            else:
                raise RuntimeError(
                    str(req_id) + " encountered error: " + response.error_msg)
    return output_tokens


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Executor Bindings Example")
    parser.add_argument("--tokenizer_path",
                        "-t",
                        type=str,
                        required=True,
                        help="Directory containing model tokenizer")
    parser.add_argument("--engine_path",
                        "-e",
                        type=str,
                        required=True,
                        help="Directory containing model engine")
    parser.add_argument("--beam_width",
                        type=int,
                        required=False,
                        default=1,
                        help="The beam width")
    parser.add_argument("--batch_size",
                        type=int,
                        required=False,
                        default=1,
                        help="The batch size")
    parser.add_argument(
        "--timeout_ms",
        type=int,
        required=False,
        default=10000,
        help="The maximum time to wait for all responses, in milliseconds")
    parser.add_argument("--lpp_batched",
                        action="store_true",
                        default=False,
                        help="Enable batched logits post processor")
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    class AnswerFormat(BaseModel):
        last_name: str
        year_of_birth: int

    parser = JsonSchemaParser(AnswerFormat.model_json_schema())
    token_enforcer = build_token_enforcer(tokenizer, parser)
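
    # Return the token ids allowed at the current decoding step. Requests with
    # client_id 0 (or no client_id) are restricted to a single arbitrary token
    # id (42) to demonstrate per-client behavior; all other requests follow the
    # JSON schema enforced by the token enforcer.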
    def get_allowed_tokens(ids, client_id):

        if client_id is None or client_id == 0:
            return [42]

        def _trim(ids):
            return [x for x in ids if x != tokenizer.eos_token_id]

        allowed = token_enforcer.get_allowed_tokens(_trim(ids[0]))
        return allowed
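
    # Per-request logits post-processor registered as "my_logits_pp": build a
    # mask on the CPU that is -inf except at the allowed token ids, move it to
    # the device on the executor's CUDA stream, and add it to the logits.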
    def logits_post_processor(req_id: int, logits: _tor.Tensor,
                              ids: _tp.List[_tp.List[int]], stream_ptr: int,
                              client_id: _tp.Optional[int]):
        mask = _tor.full_like(logits, fill_value=float("-inf"), device="cpu")
        allowed = get_allowed_tokens(ids, client_id)
        mask[:, :, allowed] = 0

        with _tor.cuda.stream(_tor.cuda.ExternalStream(stream_ptr)):
            mask = mask.to(logits.device, non_blocking=True)
            logits += mask
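
    # Batched variant: receives all requests that selected
    # BATCHED_POST_PROCESSOR_NAME at once, so build every mask on the CPU
    # first and then apply them on the shared CUDA stream.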
    def logits_post_processor_batched(
            req_ids_batch: _tp.List[int], logits_batch: _tp.List[_tor.Tensor],
            ids_batch: _tp.List[_tp.List[_tp.List[int]]], stream_ptr,
            client_ids_batch: _tp.List[_tp.Optional[int]]):
        masks = []
        for req_id, logits, ids, client_id in zip(req_ids_batch, logits_batch,
                                                  ids_batch, client_ids_batch):
            del req_id
            mask = _tor.full_like(logits,
                                  fill_value=float("-inf"),
                                  device="cpu")
            allowed = get_allowed_tokens(ids, client_id)
            mask[:, :, allowed] = 0
            masks.append(mask)

        with _tor.cuda.stream(_tor.cuda.ExternalStream(stream_ptr)):
            for logits, mask in zip(logits_batch, masks):
                logits += mask.to(logits.device, non_blocking=True)

    # Create the executor.
    executor_config = trtllm.ExecutorConfig(args.beam_width)
    logits_proc_config = trtllm.LogitsPostProcessorConfig()
    if not args.lpp_batched:
        logits_proc_config.processor_map = {
            "my_logits_pp": logits_post_processor
        }
    else:
        logits_proc_config.processor_batched = logits_post_processor_batched
    executor_config.logits_post_processor_config = logits_proc_config
    executor = trtllm.Executor(args.engine_path, trtllm.ModelType.DECODER_ONLY,
                               executor_config)
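
    # Note: the schema text itself is not appended to the prompt; the logits
    # post-processor alone constrains generation to the AnswerFormat schema.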
    input = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
    prompt = tokenizer.encode(input)
    print(f"Input text: {input}\n")

    if executor.can_enqueue_requests():
        request_ids = enqueue_requests(args, executor)
        output_tokens = wait_for_responses(args, request_ids, executor)

        # Print output
        for req_id in request_ids:
            for beam_id in range(args.beam_width):
                result = tokenizer.decode(
                    output_tokens[req_id][beam_id][len(prompt):])
                generated_tokens = len(
                    output_tokens[req_id][beam_id]) - len(prompt)
                print(
                    f"Request {req_id} Beam {beam_id} ({generated_tokens} tokens): {result}"
                )