# TensorRT-LLMs/tensorrt_llm/hlapi/llm.py
import json
import os
import time
from concurrent.futures import as_completed
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Iterable, List, Optional, Union
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from tensorrt_llm import Mapping, Module, logger
from tensorrt_llm.builder import BuildConfig, BuilderConfig
from tensorrt_llm.runtime import (GenerationSession, ModelRunner,
SamplingConfig, model_runner)
from tensorrt_llm.runtime.engine import EngineConfig
from .mpi_session import MpiSession, NodeSession, mpi_rank, mpi_size
@dataclass
class ParallelConfig:
''' The model distribution (parallelism) configuration for the LLM. '''
tp_size: int = 1
pp_size: int = 1
devices: List[int] = field(default_factory=list, init=False)
def __post_init__(self):
assert self.tp_size > 0, "tp_size should be larger than 0"
assert self.pp_size > 0, "pp_size should be larger than 0"
assert not self.devices or len(self.devices) == self.world_size
@property
def world_size(self) -> int:
return self.tp_size * self.pp_size
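# A minimal usage sketch (illustrative only): combining 2-way tensor parallelism
# with 2-way pipeline parallelism yields a world size of 4, so four CUDA devices
# would be required.
#
#   parallel_config = ParallelConfig(tp_size=2, pp_size=2)
#   assert parallel_config.world_size == 4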
@dataclass
class ModelConfig:
''' ModelConfig holds the options for a model.
An example of the usage:
# A llama-7B model
config = ModelConfig('llama-7B')
# optionally override the default options
config.build_config.max_batch_size = 64
'''
# the options shared by all models
# ``model`` can be either a model name or an in-memory model.
# If ``model`` specifies a model kind such as "llama-7B", the model will be downloaded automatically from a third-party
# model hub such as www.modelscope.cn or Hugging Face.
model: Optional[Union[str, Module]] = None
# ``model_dir`` locates a local model; the format is inferred from the model files themselves.
# HF models, TensorRT-LLM checkpoints and TensorRT-LLM engines are all supported.
model_dir: Optional[str] = None
# ``build_config`` contains the options for building the model.
build_config: BuildConfig = field(default_factory=BuildConfig)
# ``quant_config`` contains the options for quantizing the model.
# quant_config: QuantMode = QuantMode()
# ``parallel_config`` contains the options for distributed inference.
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
def __post_init__(self):
assert self.model or self.model_dir, "Either model or model_dir should be provided."
@property
def is_multi_gpu(self) -> bool:
return self.parallel_config.tp_size > 1
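# A minimal usage sketch (illustrative only; the checkpoint path is hypothetical):
#
#   config = ModelConfig(model_dir="/path/to/llama-7b-hf")
#   config.parallel_config.tp_size = 2   # enable 2-way tensor parallelism
#   assert config.is_multi_gpu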
class ModelFormatKind(Enum):
HF = 0
TLLM_CKPT = 1
TLLM_ENGINE = 2
TokenIdsTy = List[int]
@dataclass
class GenerationOuptut:
request_id: int = -1
generate_pieces: List["GenerationPiece"] = field(default_factory=list)
@dataclass
class GenerationPiece:
''' The output of the generation.
For normal text generation, there is only one GenerationPiece for a given input.
For streaming generation, there can be multiple GenerationPiece objects, one for each generated piece.
'''
index: int = 0
text: str = ""
token_ids: List[int] = field(default_factory=list)
logprobs: List[float] = field(default_factory=list)
class TokenizerBase:
''' This is a protocol for the tokenizer. Users can implement their own tokenizer by inheriting this class. '''
@property
def eos_token_id(self) -> int:
''' Return the id of the end of sentence token. '''
raise NotImplementedError()
@property
def pad_token_id(self) -> int:
''' Return the id of the padding token. '''
raise NotImplementedError()
def encode(self, text: str) -> TokenIdsTy:
''' Encode the text to token ids. '''
raise NotImplementedError()
def decode(self, token_ids: TokenIdsTy) -> str:
''' Decode the token ids to text. '''
raise NotImplementedError()
def batch_encode_plus(self, texts: List[str]) -> dict:
''' Encode the batch of texts to token ids. '''
raise NotImplementedError()
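# A minimal sketch of a user-provided tokenizer implementing the TokenizerBase
# protocol (illustrative only; a real implementation would wrap an actual
# vocabulary rather than whitespace splitting):
#
#   class WhitespaceTokenizer(TokenizerBase):
#       def __init__(self, vocab: dict):
#           self.vocab = vocab                                  # token -> id
#           self.inv_vocab = {i: t for t, i in vocab.items()}   # id -> token
#       @property
#       def eos_token_id(self) -> int:
#           return self.vocab["</s>"]
#       @property
#       def pad_token_id(self) -> int:
#           return self.vocab["<pad>"]
#       def encode(self, text: str) -> TokenIdsTy:
#           return [self.vocab[t] for t in text.split()]
#       def decode(self, token_ids: TokenIdsTy) -> str:
#           return " ".join(self.inv_vocab[i] for i in token_ids)
#       def batch_encode_plus(self, texts: List[str]) -> dict:
#           return {"input_ids": [self.encode(t) for t in texts]}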
class TransformersTokenizer(TokenizerBase):
''' A wrapper around a Hugging Face Transformers tokenizer.
This is the default tokenizer for LLM. '''
@classmethod
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
**kwargs)
return cls(tokenizer)
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
def pad_token_id(self) -> int:
return self.tokenizer.pad_token_id
def encode(self, text: str) -> TokenIdsTy:
return self.tokenizer.encode(text)
def decode(self, token_ids: TokenIdsTy) -> str:
return self.tokenizer.decode(token_ids)
def batch_encode_plus(self, texts: List[str]) -> dict:
return self.tokenizer.batch_encode_plus(texts)
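# A minimal usage sketch (illustrative only; the tokenizer directory is hypothetical).
# A pre-loaded tokenizer can be passed to LLM (defined below) to override the default:
#
#   tokenizer = TransformersTokenizer.from_pretrained("/path/to/llama-7b-hf")
#   ids = tokenizer.encode("Hello world")
#   text = tokenizer.decode(ids)
#   llm = LLM(config, tokenizer=tokenizer)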
@dataclass
class LLM:
'''
An end-to-end runner for LLM tasks.
Typical usage:
config = ModelConfig("llama-7B")
llm = LLM(config)
for output in llm(["What is your name?"]):
print(output.generate_pieces[0].text)  # => "My name is Llama."
'''
def __init__(self,
config: ModelConfig,
tokenizer: Optional[TokenizerBase] = None,
enable_tokenizer: bool = True,
disable_model_download: bool = False,
display_model_processing_summary: bool = False,
dump_model_processing_summary: Optional[str] = None):
'''
Args:
config: The model config for the model.
tokenizer: A user-provided tokenizer; it will override the default one.
enable_tokenizer: Turn on tokenizer-based preprocessing and postprocessing so that the LLM pipeline takes text as input and produces text as output.
disable_model_download: Disable downloading the HF model from third-party model hubs such as www.modelscope.cn or Hugging Face.
display_model_processing_summary: Display a summary of the model building.
dump_model_processing_summary: Dump a summary of the model building into a log file.
'''
self.config = config
self._tokenizer = tokenizer
self.enable_tokenizer = enable_tokenizer
self.disable_model_download = disable_model_download
self.display_model_processing_summary = display_model_processing_summary
self.dump_model_processing_summary = dump_model_processing_summary
if self.config.is_multi_gpu:
assert torch.cuda.is_available(), "No CUDA device is available."
assert torch.cuda.device_count() >= self.config.parallel_config.world_size, \
f"Only {torch.cuda.device_count()} CUDA devices are available, but {self.config.parallel_config.world_size} are required."
logger.warning(
f'start MpiSession with {self.config.parallel_config.tp_size} workers'
)
self.mpi_session = MpiSession(
n_workers=self.config.parallel_config.tp_size)
self._build_model()
def __call__(
self,
prompts: List[str] | List[TokenIdsTy],
sampling_config: Optional[SamplingConfig] = None
) -> Iterable[GenerationOuptut]:
''' Generate the output for the given inputs.
Args:
prompts: The raw text or token ids to the model.
sampling_config: The sampling config for the generation, a default one will be used if not provided.
'''
if self.config.is_multi_gpu:
return self._generate_sync_multi_gpu(prompts, sampling_config)
else:
return self._generate_sync(prompts, self.runtime_stuff,
sampling_config)
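# A minimal usage sketch (illustrative only): generate for a batch of prompts and,
# optionally, override the default sampling configuration.
#
#   outputs = llm(["What is your name?", "What is the capital of France?"])
#   for out in outputs:
#       print(out.generate_pieces[0].text)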
def __getstate__(self):
# Customize the members to be pickled.
# We should not pickle huge objects like the tokenizer, since `self` may be pickled and sent to the MPI nodes on each submit().
state = self.__dict__.copy()
def rm(attr):
if attr in state:
del state[attr]
for key in list(state.keys()):
if key == "_tokenizer":
# A user-provided tokenizer should be distributed to the MPI nodes to override the default one.
continue
if key.startswith('_'):
del state[key]
rm("runtime_stuff")
rm("mpi_session")
# TODO[chunweiy]: Disable config pickle later
# rm_attr("config")
return state
def _generate_sync(self, prompts, runtime_stuff: "_ModelRuntimeStuff",
sampling_config) -> Iterable[GenerationOuptut]:
''' Generate in sync mode on a single GPU. '''
sampling_config = sampling_config or self.default_sampling_config
assert sampling_config is not None, "The sampling_config needs to be provided."
build_config = self.config.build_config
if not prompts:
return []
assert runtime_stuff.runner, "The model runner is not built yet."
def generate_batch(batch_input_ids: List[torch.Tensor]):
batch_input_ids = [
torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
] # List[torch.Tensor(seq)]
assert len(batch_input_ids) <= build_config.max_batch_size, \
f"Cannot run a batch size larger than {build_config.max_batch_size}, got {len(batch_input_ids)}"
outputs = runtime_stuff.runner.generate(batch_input_ids,
sampling_config)
output_ids = outputs['output_ids']
sequence_lengths = outputs['sequence_lengths']
batch_size, num_beams, max_len = output_ids.size()
input_lengths = [x.size(0) for x in batch_input_ids]
assert num_beams == 1, "Beam search is not supported yet"
for batch_idx in range(batch_size):
for beam in range(num_beams):
inputs = output_ids[batch_idx][
0][:input_lengths[batch_idx]].tolist()
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[batch_idx][beam]
outputs = output_ids[batch_idx][beam][
output_begin:output_end].tolist()
output_text = runtime_stuff.tokenizer.decode(
outputs) if runtime_stuff.tokenizer else None
# get a sequence for each prompt directly
piece = GenerationPiece(text=output_text, token_ids=outputs)
yield GenerationOuptut(generate_pieces=[piece])
tokenizer = runtime_stuff.tokenizer
def batching_prompts(prompts):
need_tokenize: bool = isinstance(prompts[0], str)
if need_tokenize:
assert tokenizer, "The tokenizer is not built or provided."
def process_batch(batch):
return tokenizer.batch_encode_plus(
batch)['input_ids'] if need_tokenize else batch
batch = []
for i, prompt in enumerate(prompts):
batch.append(prompt)
if len(batch) >= build_config.max_batch_size:
yield process_batch(batch)
batch = []
if batch:
yield process_batch(batch)
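# e.g. with max_batch_size=2, prompts [p0, p1, p2] are yielded as [p0, p1] and then [p2].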
for batch in batching_prompts(prompts):
outs = generate_batch(batch)
for o in outs:
yield o
def _generate_sync_multi_gpu(
self, prompts, sampling_config: Optional[SamplingConfig]
) -> Iterable[GenerationOuptut]:
# TODO[chunweiy]: May merge this with the single-GPU version later
assert self.config.is_multi_gpu, "The model is not distributed."
features = self.mpi_session.submit(self._node_generation_task, prompts,
sampling_config)
res = [feature.result() for feature in as_completed(features)]
# TODO[chunweiy]: make sure that the root's output is always the first
return res[0]
def _node_generation_task(
self, prompts: List[str] | List[List[int]],
sampling_config: Optional[SamplingConfig]
) -> List[GenerationOuptut]:
assert NodeSession.is_initialized(), "Model is not built yet."
assert isinstance(NodeSession.state, _ModelRuntimeStuff)
model: _ModelRuntimeStuff = NodeSession.state
if sampling_config is None:
sampling_config = SamplingConfig(
end_id=model.tokenizer.eos_token_id,
pad_id=model.tokenizer.eos_token_id
if model.tokenizer.pad_token_id is None else
model.tokenizer.pad_token_id,
output_sequence_lengths=True,
return_dict=True) if model.tokenizer else None
return list(self._generate_sync(prompts, model, sampling_config))
@property
def tokenizer(self) -> TokenizerBase:
return self._tokenizer or self.runtime_stuff.tokenizer
def save(self, engine_dir: str):
''' Save the built engine to the given path. '''
if self.config.is_multi_gpu:
futures = self.mpi_session.submit(
self._node_save_task, engine_dir, self.config.model_dir,
self.config.parallel_config.pp_size,
self.config.parallel_config.tp_size)
for future in futures:
future.result()
else:
_ModelBuilder.save(self.runtime_stuff,
self.config.model_dir,
engine_dir=engine_dir)
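# A minimal usage sketch (illustrative only; the directory is hypothetical): persist
# the built engine and reload it later, skipping the build step.
#
#   llm.save("./llama-7b-engine")
#   llm2 = LLM(ModelConfig(model_dir="./llama-7b-engine"))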
def _node_save_task(self, engine_dir: str, model_dir: str, pp_size: int,
tp_size: int):
runtime_stuff = NodeSession.state
mapping = Mapping(world_size=mpi_size(),
rank=mpi_rank(),
tp_size=tp_size,
pp_size=pp_size)
_ModelBuilder.save(runtime_stuff,
model_dir,
engine_dir=engine_dir,
mapping=mapping)
def _build_model(self):
if self.config.is_multi_gpu:
futures = self.mpi_session.submit(
self._node_build_task, self.config.parallel_config.tp_size,
self.config.parallel_config.pp_size, self._tokenizer)
res = []
for future in as_completed(futures):
res.append(future.result())
return bool(res)
else:
model_builder = _ModelBuilder(self.config,
self.enable_tokenizer,
tokenizer=self._tokenizer)
self.runtime_stuff = model_builder()
self._tokenizer = self.runtime_stuff.tokenizer
self.default_sampling_config = SamplingConfig(
end_id=self.tokenizer.eos_token_id,
pad_id=self.tokenizer.eos_token_id
if self.tokenizer.pad_token_id is None else
self.tokenizer.pad_token_id,
output_sequence_lengths=True,
return_dict=True) if self.tokenizer else None
return True
def _node_build_task(self,
tp_size: int,
pp_size: int,
tokenizer: TokenizerBase = None):
assert not NodeSession.is_initialized()
mapping = Mapping(tp_size=tp_size,
pp_size=pp_size,
rank=mpi_rank(),
world_size=tp_size * pp_size)
model_builder = _ModelBuilder(self.config,
self.enable_tokenizer,
tokenizer=tokenizer,
mapping=mapping)
runtime_stuff = model_builder()
# Hold the model builder for later use
NodeSession.state = runtime_stuff
@dataclass
class _ModelRuntimeStuff:
''' _ModelRuntimeStuff holds the minimum runtime state needed for running a model.
It is held as a runtime cache on the MPI nodes in multi-GPU mode.
'''
runner: Optional[ModelRunner] = None
tokenizer: Optional[TokenizerBase] = None
# engine is only used for saving the engine to disk
engine: Optional[bytes] = None
# engine_config is only used for saving the engine to disk
engine_config: Optional[dict | BuildConfig | EngineConfig] = None
@property
def model_structure(self) -> str:
# "llama" or "opt" and so on
return self.engine_config['builder_config']['name'] if isinstance(
self.engine_config, dict) else self.engine_config.name
class _ModelBuilder:
''' The model builder is used to build an end-to-end model pipeline from a model config.
It will construct the runtime resources including engine, tokenizer, model runner, etc.
'''
def __init__(
self,
config: ModelConfig,
enable_tokenizer: bool,
tokenizer: Optional[TokenizerBase],
mapping: Optional[Mapping] = None,
display_model_processing_summary: bool = False,
):
self.config = config
self.enable_tokenizer = enable_tokenizer
self.tokenizer = tokenizer
self.mapping = mapping
self.display_model_processing_summary = display_model_processing_summary
self._model_pipeline = []
# Prepare the model processing pipeline
if isinstance(self.config.model, Module):
''' Build engine from user provided model '''
raise NotImplementedError()
self._model_pipeline.append(
("build_engine", self._build_engine_from_inmemory_model))
return
self._model_dir = self.config.model_dir
self._model_name = self.config.model
self._model_structure = None
if self.config.model_dir is None:
''' Download HF model if necessary '''
raise NotImplementedError()
assert self.config.model is not None, "Either model_dir or model should be provided."
self._model_pipeline.append(
("download_hf_model", self._download_hf_model))
self._model_dir = self._cache_manager.hf_checkpoint_dir()
self._model_format = self._get_model_format(self._model_dir)
self._model_name = self._model_name or self._get_model_kind(
self._model_dir)
if self._model_format is ModelFormatKind.HF:
''' HF -> TFRT checkpoints -> engine '''
self._model_pipeline.append(
("hf_to_trtllm", self._build_model_from_hf))
self._model_pipeline.append(
("build_engine", self._build_engine_and_model_runner))
elif self._model_format is ModelFormatKind.TLLM_CKPT:
''' TFRT checkpoints -> engine '''
raise NotImplementedError()
elif self._model_format is ModelFormatKind.TLLM_ENGINE:
''' TFRT engine '''
self._model_pipeline.append(
("load_engine", self._load_model_runner))
if self.enable_tokenizer and not self.tokenizer:
''' Use the default tokenizer if user doesn't provide one '''
self._model_pipeline.append(
("init_tokenizer", self._init_default_tokenizer))
def __call__(self) -> _ModelRuntimeStuff:
if self.config.is_multi_gpu:
torch.cuda.set_device(self.mapping.rank)
for step_name, step in tqdm(self._model_pipeline,
desc="Model preprocessing"):
# Each step could have a separate progress bar
# e.g. the download_hf_model step or the build_engine step which is time-consuming
if self.config.is_multi_gpu:
# TODO[chunweiy]: here is for debugging, remove later
print(f"\n#rank-{mpi_rank()} Executing {step_name}")
else:
print(f"\nExecuting {step_name}")
start_time = time.time()
step()
end_time = time.time()
logger.warning(
f"Finished executing step {step_name} in {end_time - start_time} seconds"
)
if self.display_model_processing_summary:
self._display_summary()
assert self._model_structure is not None, "The model structure is not set yet."
assert self.runner is not None, "The model runner is not built yet."
assert hasattr(self, '_builder_config') or hasattr(
self, '_engine_config'), "config is not loaded."
config = self._engine_config if hasattr(
self, '_engine_config') else self._builder_config
return _ModelRuntimeStuff(tokenizer=self.tokenizer,
runner=self.runner,
engine=self._engine,
engine_config=config)
@staticmethod
def save(model: _ModelRuntimeStuff,
model_dir: str,
engine_dir: str,
mapping=None):
''' Save the built engine to the given path. '''
from tensorrt_llm.builder import Builder
mapping = mapping or Mapping()
rank = mapping.rank if mapping else 0
def save_engine_to_dir(engine_dir):
# TODO[chunweiy, tao]: Fix this. self.model is deleted after the build step, which is why self.model.save is not used here.
def get_engine_name(model, dtype, tp_size, pp_size, rank):
if pp_size == 1:
return '{}_{}_tp{}_rank{}.engine'.format(
model, dtype, tp_size, rank)
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(
model, dtype, tp_size, pp_size, rank)
engine_dir = Path(engine_dir)
if not engine_dir.exists():
engine_dir.mkdir()
config_path = engine_dir / 'config.json'
# TODO[chunweiy]: get dtype from config
engine_path = engine_dir / get_engine_name(
model.model_structure, 'float16', mapping.tp_size,
mapping.pp_size, rank)
builder = Builder()
# write config.json
if isinstance(model.engine_config, BuilderConfig):
builder.save_config(model.engine_config, config_path)
elif isinstance(model.engine_config, dict):
with open(config_path, 'w') as f:
json.dump(model.engine_config, f)
else:
raise ValueError("wrong engine_config type")
logger.debug(f"Saving engine to {engine_path}")
with open(engine_path, 'wb') as f:
f.write(model.engine)
def copy_hf_tokenizer_stuff_to_engine_dir():
# Copy the HF tokenizer files to the engine dir so that the engine dir can be used as a standalone model dir that supports end-to-end tasks.
# This is only done for HF models for now; it is not available for user-customized tokenizers.
import shutil
for name in os.listdir(model_dir):
src = os.path.join(model_dir, name)
dst = os.path.join(engine_dir, name)
if name.startswith('tokenizer'):
if os.path.isdir(src):
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
shutil.copy2(src, dst)
save_engine_to_dir(engine_dir)
if isinstance(model.tokenizer, TransformersTokenizer):
if mapping is None or mapping.rank == 0:
copy_hf_tokenizer_stuff_to_engine_dir()
def _get_model_format(self, model_dir: str) -> ModelFormatKind:
''' Determine the format of the model. '''
# TODO[chunweiy]: refine this
return ModelFormatKind.HF if Path.exists(
Path(model_dir) /
'generation_config.json') else ModelFormatKind.TLLM_ENGINE
def _get_model_kind(self, model_dir: str) -> str:
''' Determine the kind of the model, e.g. "llama". '''
# TODO[chunweiy]: refine this
return 'llama'
def _download_hf_model(self):
''' Download the HF model from a third-party model hub such as www.modelscope.cn or Hugging Face. '''
raise NotImplementedError()
def _build_model_from_hf(self):
''' Build a TRT-LLM model from an HF model. '''
from tensorrt_llm.models import LLaMAForCausalLM
# TODO[chunweiy]: inspect from hf model/config
model_structure = 'LLaMaForCausalLM'
# TODO[chunweiy]: add more models
model2struct = dict(LLaMaForCausalLM=LLaMAForCausalLM)
self.model = model2struct[model_structure].from_hugging_face(
self._model_dir, mapping=self.mapping)
def _load_model_runner(self):
''' Load a model runner from a TRT-LLM engine. '''
assert self._model_dir
logger.warning(f"Loading model runner from {self._model_dir}")
self.runner = ModelRunner.from_dir(self._model_dir)
self._engine = self.runner.session.runtime.engine
with open(os.path.join(self._model_dir, 'config.json'), 'r') as f:
self._engine_config = json.load(f)
self._model_structure = self._engine_config['builder_config'][
'name']
def _build_engine_and_model_runner(self):
''' Build a TensorRT-LLM engine from an in-memory model.
The model runner will be created as well.
'''
from tensorrt_llm.mapping import Mapping
self._engine, self._builder_config = self.model.to_trt(
self.config.build_config.max_batch_size,
self.config.build_config.max_input_len,
self.config.build_config.max_output_len)
self._model_structure = self._builder_config.name
# TODO[chunweiy]: Fix this, plugin_config should not be None, which triggers OOTB mode
# plugin_config=self.config.build_config.plugin_config)
# delete the model explicitly to free all the build-time resources
del self.model
# TODO [chunweiy]: Is this conversion necessary?
model_config, other_config = model_runner._builder_to_model_config(
self._builder_config.to_dict())
max_batch_size = other_config.get('max_batch_size')
max_input_len = other_config.get('max_input_len')
max_output_len = other_config.get('max_output_len')
max_beam_width = other_config.get('max_beam_width')
runtime_mapping = self.mapping or Mapping()
session = GenerationSession(model_config, self._engine, runtime_mapping)
self.runner = ModelRunner(session, max_batch_size, max_input_len,
max_output_len, max_beam_width)
def _init_default_tokenizer(self):
self.tokenizer = TransformersTokenizer.from_pretrained(
self._model_dir,
legacy=False,
padding_side='left',
truncation_side='left',
trust_remote_code=True,
use_fast=True)
def _convert_hf_to_trtllm_checkpoints(self):
'''
Convert a HuggingFace model to TensorRT-LLM checkpoints.
The checkpoints will be cached in the cache directory.
'''
raise NotImplementedError()
def _quantize(self):
''' Quantize TensorRT-LLM checkpoints.
The quantized checkpoints will be cached in the cache directory.
'''
raise NotImplementedError()
def _display_summary(self):
''' Display the summary of the model.
The following information could be displayed:
- model kind
- quantization information
- runtime setting information
- cache information
and so on.
'''
raise NotImplementedError()