Update TensorRT-LLM (#613)
* Update TensorRT-LLM

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: zhang-ge-hao <842720660@qq.com>
parent 42af740db5
commit f7eca56161
@@ -44,4 +44,5 @@ repos:
  - id: codespell
    args:
      - --skip=".git,3rdparty"
      - --exclude-file=examples/whisper/tokenizer.py
      - --ignore-words-list=rouge,inout,atleast,strat

@@ -257,7 +257,7 @@ The list of supported models is:
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [Mistral](examples/llama)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
@@ -266,9 +266,10 @@ The list of supported models is:
* [SantaCoder](examples/gpt)
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)
* [Whisper](examples/whisper)

Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
support that contains many encoder-decoder models such as T5, Flan-T5, etc. We
functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, etc. We
unroll the exact model names in the list above to let users find specific
models easier.

@@ -372,7 +373,7 @@ For example: `mpirun -n 1 python3 examples/gpt/build.py ...`

### Change Log

#### Version 0.6.0
#### Version 0.6.1

* Models
  * ChatGLM3

@@ -257,6 +257,7 @@ int main(int argc, char* argv[])

    options.add_options()("ctx_micro_batch_size", "Batch size for context phase.", cxxopts::value<int>());
    options.add_options()("gen_micro_batch_size", "Batch size for generation phase.", cxxopts::value<int>());
    options.add_options()("max_attention_window", "Max kv cache length per sequence.", cxxopts::value<int>());
    options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
    options.add_options()(
        "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@@ -352,6 +353,11 @@ int main(int argc, char* argv[])
    {
        sessionConfig.kvCacheConfig.maxTokens = result["max_tokens_in_paged_kvcache"].as<int>();
    }
    // Argument: Max KV Cache Length
    if (result.count("max_attention_window"))
    {
        sessionConfig.kvCacheConfig.maxAttentionWindow = result["max_attention_window"].as<int>();
    }
    // Argument: K-V Cache Free Gpu Mem Fraction
    if (result.count("kv_cache_free_gpu_mem_fraction"))
    {

@@ -12,6 +12,7 @@ The benchmark implementation and entrypoint can be found in [`benchmarks/python/
* [`benchmarks/python/base_benchmark.py`](./base_benchmark.py) to implement the base class for benchmark.
* [`benchmarks/python/gpt_benchmark.py`](./gpt_benchmark.py) to implement benchmark scripts for GPT and GPT-like(LLaMA/OPT/GPT-J/SmoothQuant-GPT) models.
* [`benchmarks/python/bert_benchmark.py`](./bert_benchmark.py) to implement benchmark scripts for BERT models.
* [`benchmarks/python/enc_dec_benchmark.py`](./enc_dec_benchmark.py) to implement benchmark scripts for Encoder-Decoder models.

## Usage

@@ -27,7 +27,7 @@ class BuildConfig(BaseModel, extra=Extra.allow):
    hidden_act: Optional[str]
    n_positions: int
    max_batch_size: int
    max_input_len: int
    max_input_len: Optional[int] = None
    num_kv_heads: Optional[int] = None
    max_output_len: Optional[int] = None
    max_beam_width: int = 1
@@ -54,10 +54,25 @@ class BuildConfig(BaseModel, extra=Extra.allow):
    moe_top_k: int = None


class EncDecBuildConfig(BuildConfig, extra=Extra.allow):
    num_decoder_layers: Optional[int] = None
    head_size: Optional[int] = None
    ffn_hidden_size: Optional[int] = None
    num_buckets: Optional[int] = None
    max_distance: Optional[int] = None
    max_encoder_input_len: Optional[int] = None
    max_decoder_input_len: Optional[int] = None

    def __post_init__(self) -> None:
        assert self.head_size is not None
        assert self.ffn_hidden_size is not None
        assert self.num_buckets is not None


class ModelConfig(BaseModel):
    name: str
    family: str
    benchmark_type: Literal["gpt", "bert"]
    benchmark_type: Literal["gpt", "bert", "enc_dec"]
    build_config: BuildConfig

@@ -263,7 +278,7 @@ _allowed_configs = {
            hidden_size=4096,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=11008,
            max_batch_size=128,
            max_input_len=512,
@@ -280,7 +295,7 @@ _allowed_configs = {
            hidden_size=5120,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=13824,
            max_batch_size=128,
            max_input_len=512,
@@ -315,7 +330,7 @@ _allowed_configs = {
            hidden_size=8192,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=28672,
            max_batch_size=64,
            max_input_len=512,
@@ -332,7 +347,7 @@ _allowed_configs = {
            hidden_size=8192,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=28672,
            max_batch_size=16,
            max_input_len=8000,
@@ -349,7 +364,7 @@ _allowed_configs = {
            hidden_size=8192,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=28672,
            max_batch_size=64,
            max_input_len=200,
@@ -366,7 +381,7 @@ _allowed_configs = {
            hidden_size=8192,
            vocab_size=32000,
            hidden_act='silu',
            n_positions=2048,
            n_positions=4096,
            inter_size=28672,
            max_batch_size=128,
            max_input_len=512,
@@ -588,7 +603,7 @@ _allowed_configs = {
            bias=False,
            use_alibi=False,
            parallel_attention=True,
            new_decoder_architecture=False,
            new_decoder_architecture=True,
        )),
    "falcon_180b":
    ModelConfig(name="falcon_180b",
@@ -609,7 +624,217 @@ _allowed_configs = {
            bias=False,
            use_alibi=False,
            parallel_attention=True,
            new_decoder_architecture=False,
            new_decoder_architecture=True,
        )),
    "t5_small":
    ModelConfig(name="t5_small",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=6,
            num_heads=8,
            head_size=64,
            ffn_hidden_size=2048,
            hidden_size=512,
            vocab_size=32128,
            hidden_act="relu",
            n_positions=512,
            num_buckets=32,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "t5_base":
    ModelConfig(name="t5_base",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=12,
            num_heads=12,
            head_size=64,
            ffn_hidden_size=3072,
            hidden_size=768,
            vocab_size=32128,
            hidden_act="relu",
            n_positions=512,
            num_buckets=32,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "t5_large":
    ModelConfig(name="t5_large",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_heads=16,
            head_size=64,
            ffn_hidden_size=4096,
            hidden_size=1024,
            vocab_size=32128,
            hidden_act="relu",
            n_positions=512,
            num_buckets=32,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "t5_3b":
    ModelConfig(name="t5_3b",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_heads=32,
            head_size=128,
            ffn_hidden_size=16384,
            hidden_size=1024,
            vocab_size=32128,
            hidden_act="relu",
            n_positions=512,
            num_buckets=32,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "t5_11b":
    ModelConfig(name="t5_11b",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_heads=128,
            head_size=128,
            ffn_hidden_size=65536,
            hidden_size=1024,
            vocab_size=32128,
            hidden_act="relu",
            n_positions=512,
            num_buckets=32,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "flan_t5_small":
    ModelConfig(name="flan_t5_small",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=8,
            num_decoder_layers=8,
            num_heads=6,
            head_size=64,
            ffn_hidden_size=1024,
            hidden_size=512,
            vocab_size=32128,
            hidden_act="gated-gelu",
            n_positions=512,
            num_buckets=32,
            max_distance=128,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "flan_t5_base":
    ModelConfig(name="flan_t5_base",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=12,
            num_decoder_layers=12,
            num_heads=12,
            head_size=64,
            ffn_hidden_size=2048,
            hidden_size=768,
            vocab_size=32128,
            hidden_act="gated-gelu",
            n_positions=512,
            num_buckets=32,
            max_distance=128,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "flan_t5_large":
    ModelConfig(name="flan_t5_large",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_decoder_layers=24,
            num_heads=16,
            head_size=64,
            ffn_hidden_size=2816,
            hidden_size=1024,
            vocab_size=32128,
            hidden_act="gated-gelu",
            n_positions=512,
            num_buckets=32,
            max_distance=128,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "flan_t5_xl":
    ModelConfig(name="flan_t5_xl",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_decoder_layers=24,
            num_heads=32,
            head_size=64,
            ffn_hidden_size=5120,
            hidden_size=2048,
            vocab_size=32128,
            hidden_act="gated-gelu",
            n_positions=512,
            num_buckets=32,
            max_distance=128,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
    "flan_t5_xxl":
    ModelConfig(name="flan_t5_xxl",
        family="t5",
        benchmark_type="enc_dec",
        build_config=EncDecBuildConfig(
            num_layers=24,
            num_decoder_layers=24,
            num_heads=64,
            head_size=64,
            ffn_hidden_size=10240,
            hidden_size=4096,
            vocab_size=32128,
            hidden_act="gelu_new",
            n_positions=0,
            num_buckets=32,
            max_distance=128,
            max_batch_size=8,
            max_encoder_input_len=1024,
            max_decoder_input_len=1,
            max_output_len=200,
            builder_opt=None,
        )),
}

@@ -59,12 +59,18 @@ def serialize_engine(engine, path):

class BaseBenchmark(object):

    def __init__(self, engine_dir, model_name, dtype):
    def __init__(self,
                 engine_dir,
                 model_name,
                 dtype,
                 rank,
                 world_size,
                 serial_build: bool = False):
        self.engine_dir = engine_dir
        self.model_name = model_name
        self.dtype = dtype
        self.runtime_rank = tensorrt_llm.mpi_rank()
        self.world_size = tensorrt_llm.mpi_world_size()
        self.runtime_rank = rank
        self.world_size = world_size
        self.engine_model_name = model_name
        self.quant_mode = QuantMode(0)
        self.enable_fp8 = False
@@ -100,8 +106,9 @@ class BaseBenchmark(object):
        self.runtime_mapping = tensorrt_llm.Mapping(world_size=self.world_size,
                                                    rank=self.runtime_rank,
                                                    tp_size=self.world_size)
        torch.cuda.set_device(self.runtime_rank %
                              self.runtime_mapping.gpus_per_node)
        if not serial_build:
            torch.cuda.set_device(self.runtime_rank %
                                  self.runtime_mapping.gpus_per_node)

        self.csv_filename = "" # lazy init

@@ -141,6 +141,24 @@ def parse_arguments():
        help=
        ('If this option is specified, it will override the max input len of '
         'TRT engines to the specified value instead of using pre-defined one'))
    parser.add_argument(
        '--max_encoder_input_len',
        type=int,
        default=None,
        help=
        ('This argument is only for encoder-decoder models'
         'If this option is specified, it will override the max encoder input len of TRT engines to the specified value instead of using pre-defined one'
         'By default when this option is not used, it will use pre-defined max encoder input len'
         ))
    parser.add_argument(
        '--max_decoder_input_len',
        type=int,
        default=None,
        help=
        ('This argument is only for encoder-decoder models'
         'If this option is specified, it will override the max decoder input len of TRT engines to the specified value instead of using pre-defined one'
         'By default when this option is not used, it will use pre-defined max decoder input len'
         ))
    parser.add_argument(
        '--max_output_len',
        type=int,
@@ -181,6 +199,33 @@ def parse_arguments():
            'int4_weight_only_awq', 'int4_weight_only_gptq'
        ],
        help="Optimize the model with specified quantization recipe")
    parser.add_argument(
        '--build_only',
        default=False,
        action='store_true',
        help=
        "Build engine only and skip inference, this can help to benchmark the build time on single gpu node for multi GPU model, where the inference is not possible"
    )

    parser.add_argument('--serial_build',
                        default=False,
                        action='store_true',
                        help="Build engines serially")

    parser.add_argument(
        '--rank',
        type=int,
        default=None,
        help=
        "The rank of the model to be built, only used when --build_only and --serial_build is specified"
    )
    parser.add_argument(
        '--world_size',
        type=int,
        default=None,
        help=
        "The number of gpus to be used for inference, only used when --build_only and --serial_build is specified"
    )

    return parser.parse_args()

@@ -193,9 +238,11 @@ def main(args):
    from allowed_configs import get_allowed_models
    from benchmark_profiler import BenchmarkProfiler
    from bert_benchmark import BERTBenchmark
    from enc_dec_benchmark import EncDecBenchmark
    from gpt_benchmark import GPTBenchmark
    from mem_monitor import MemoryMonitor

    import tensorrt_llm
    from tensorrt_llm.logger import logger

    logger.set_level(args.log_level)
@@ -206,20 +253,40 @@ def main(args):
    # Input length (for BERT-like models)
    input_len_options = args.input_len.split(';')
    input_len_options = [int(i) for i in input_len_options]
    # Input-output length combination (for GPT-like models)
    # Input-output length combination (for GPT-like models and enc_dec models)
    in_out_len_options = args.input_output_len.split(';')
    in_out_len_options = [[int(i) for i in io.split(',')]
                          for io in in_out_len_options]

    if args.serial_build and not args.build_only:
        raise Exception(
            f"--serial_build must be used with --build_only, always need to parallel build to do inference in the same process"
        )

    if args.build_only and args.serial_build and args.rank is not None and args.world_size is not None:
        rank = args.rank
        world_size = args.world_size
    else:
        rank = tensorrt_llm.mpi_rank()
        world_size = tensorrt_llm.mpi_world_size()

    benchmark_profiler = None
    if args.model in get_allowed_models(benchmark_type="gpt"):
        benchmark_profiler = BenchmarkProfiler()
        benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options)
        benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options,
                                   rank, world_size)
    elif args.model in get_allowed_models(benchmark_type="bert"):
        benchmarker = BERTBenchmark(args, batch_size_options, input_len_options)
        benchmarker = BERTBenchmark(args, batch_size_options, input_len_options,
                                    rank, world_size)
    elif args.model in get_allowed_models(benchmark_type="enc_dec"):
        benchmarker = EncDecBenchmark(args, batch_size_options,
                                      in_out_len_options, rank, world_size)
    else:
        raise Exception(f'Unexpected model: {args.model}')

    if args.build_only:
        return

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    benchmarker.print_report_header(args.csv,

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

# isort: off
import torch
@@ -30,8 +29,9 @@ from tensorrt_llm.runtime import TensorInfo

class BERTBenchmark(BaseBenchmark):

    def __init__(self, args, batch_sizes, in_lens):
        super().__init__(args.engine_dir, args.model, args.dtype)
    def __init__(self, args, batch_sizes, in_lens, rank, world_size):
        super().__init__(args.engine_dir, args.model, args.dtype, rank,
                         world_size, args.serial_build)
        self.batch_sizes = batch_sizes
        self.in_lens = in_lens
        self.build_time = 0
@@ -49,12 +49,17 @@ class BERTBenchmark(BaseBenchmark):
            setattr(self, key, value)
        if args.force_num_layer_1:
            self.num_layers = 1
        if args.max_batch_size is not None:
            self.max_batch_size = args.max_batch_size
        if args.max_input_len is not None:
            self.max_input_len = args.max_input_len

        start = time.time()
        engine_buffer = build_bert(args)
        self.build_time = round(time.time() - start, 2)
        engine_buffer, build_time = build_bert(args)
        self.build_time = build_time

        assert engine_buffer is not None
        if args.build_only:
            return

        self.session = tensorrt_llm.runtime.Session.from_serialized_engine(
            engine_buffer)

@@ -15,6 +15,7 @@
import argparse
import multiprocessing as mp
import os
import time
from collections import OrderedDict

import tensorrt as trt
@@ -28,7 +29,7 @@ from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import quantize_model
from tensorrt_llm.models import PretrainedConfig, quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
@@ -117,6 +118,25 @@ def parse_arguments():
                        default=False,
                        action='store_true',
                        help='Quick sanity check with num_layer=1.')
    parser.add_argument('--serial_build',
                        default=False,
                        action='store_true',
                        help="Build engines serially")

    parser.add_argument(
        '--rank',
        type=int,
        default=None,
        help=
        "The rank of the model to be built, only used when --serial_build is specified"
    )
    parser.add_argument(
        '--world_size',
        type=int,
        default=None,
        help=
        "The number of gpus to be used for inference, only used when --serial_build is specified"
    )

    return parser.parse_args()

@@ -194,9 +214,15 @@ def build_gpt(args):
        build_config['num_layers'] = 1

    # More parameters
    world_size = tensorrt_llm.mpi_world_size()
    runtime_rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(runtime_rank)
    if args.serial_build and args.rank is not None and args.world_size is not None:
        runtime_rank = args.rank
        world_size = args.world_size
    else:
        runtime_rank = tensorrt_llm.mpi_rank()
        world_size = tensorrt_llm.mpi_world_size()
    if not args.serial_build:
        torch.cuda.set_device(runtime_rank)

    num_kv_heads = build_config['num_heads'] \
        if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
    apply_query_key_layer_scaling = False
@@ -265,18 +291,26 @@ def build_gpt(args):
            moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
                build_config["moe_num_experts"], build_config["moe_top_k"]))
    elif family == "opt":
        tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
            num_layers=build_config['num_layers'],
            num_heads=build_config['num_heads'],
            hidden_size=build_config['hidden_size'],
            vocab_size=build_config['vocab_size'],
            hidden_act=build_config['hidden_act'],
            max_position_embeddings=build_config['n_positions'],
            dtype=kv_dtype,
            mapping=tensorrt_llm.Mapping(world_size=world_size,
                                         tp_size=world_size), # TP only
            pre_norm=build_config['pre_norm'],
            do_layer_norm_before=build_config['do_layer_norm_before'])
        config = {
            'architecture': 'OPTForCausalLM',
            'dtype': args.dtype,
            'vocab_size': build_config['vocab_size'],
            'hidden_size': build_config['hidden_size'],
            'num_hidden_layers': build_config['num_layers'],
            'num_attention_heads': build_config['num_heads'],
            'hidden_act': build_config['hidden_act'],
            'max_position_embeddings': build_config['n_positions'],
            'mapping': {
                'world_size': world_size,
                'tp_size': world_size
            },
            'use_parallel_embedding': False,
            'share_embedding_table': False,
            'embedding_sharding_dim': 0,
            'do_layer_norm_before': build_config['do_layer_norm_before']
        }
        config = PretrainedConfig.from_dict(config)
        tensorrt_llm_model = tensorrt_llm.models.OPTForCausalLM(config)
    elif family == "llama":
        tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
            num_layers=build_config['num_layers'],
@@ -467,8 +501,10 @@ def build_gpt(args):
    tensorrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    start = time.time()
    engine = builder.build_engine(network, builder_config)
    assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
    build_time = round(time.time() - start, 2)

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
@@ -478,7 +514,7 @@ def build_gpt(args):
        config_path = os.path.join(args.output_dir, 'config.json')
        builder_config.plugin_config = network.plugin_config
        builder.save_config(builder_config, config_path)
    return engine
    return engine, build_time


def build_bert(args):
@@ -487,9 +523,15 @@ def build_bert(args):
        build_config['num_layers'] = 1

    # More parameters
    world_size = tensorrt_llm.mpi_world_size()
    runtime_rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(runtime_rank)
    if args.serial_build and args.rank is not None and args.world_size is not None:
        runtime_rank = args.rank
        world_size = args.world_size
    else:
        runtime_rank = tensorrt_llm.mpi_rank()
        world_size = tensorrt_llm.mpi_world_size()
    if not args.serial_build:
        torch.cuda.set_device(runtime_rank)

    num_kv_heads = build_config['num_heads'] \
        if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
    max_batch_size = build_config['max_batch_size'] \
@@ -575,8 +617,10 @@ def build_bert(args):
    hidden_states.mark_output('hidden_states', hidden_states_dtype)

    # Network -> Engine
    start = time.time()
    engine = builder.build_engine(network, builder_config)
    assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
    build_time = round(time.time() - start, 2)

    if args.output_dir is not None:
        if not os.path.exists(args.output_dir):
@@ -587,7 +631,7 @@ def build_bert(args):
        config_path = os.path.join(args.output_dir, 'config.json')
        builder_config.plugin_config = network.plugin_config
        builder.save_config(builder_config, config_path)
    return engine
    return engine, build_time


def main(args):

benchmarks/python/enc_dec_benchmark.py (new file, 341 lines)
@@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import torch
from base_benchmark import BaseBenchmark, get_engine_name

import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch


class EncDecBenchmark(BaseBenchmark):

    def __init__(self, args, batch_sizes, in_out_lens, rank, world_size):
        self.engine_dir = args.engine_dir
        self.model_name = args.model
        self.mode = args.mode
        self.enable_fp8 = False # hardcode for enc-dec models
        self.dtype = args.dtype
        self.output_dir = args.output_dir
        self.runtime_rank = rank
        self.world_size = world_size
        self.csv_filename = "" # lazy init
        self.batch_sizes = batch_sizes
        self.in_out_lens = in_out_lens
        self.num_beams = args.num_beams
        self.build_time = 0
        # In current implementation, encoder and decoder have the same name,
        # builder config, and plugin config. But they can be different in the future.
        # So we use separate variables for encoder and decoder here.
        self.encoder_engine_model_name = args.model
        self.decoder_engine_model_name = args.model

        if self.engine_dir is not None:

            def read_config(component):
                config_path = os.path.join(self.engine_dir, component,
                                           "config.json")
                with open(config_path, "r") as f:
                    config = json.load(f)
                # Sanity checks
                config_dtype = config["builder_config"]["precision"]
                assert (
                    self.dtype == config_dtype
                ), f"Engine dtype ({config_dtype}) != Runtime dtype ({self.dtype})"
                world_size = config["builder_config"]["tensor_parallel"]
                assert (
                    world_size == self.world_size
                ), f"Engine world size ({world_size}) != Runtime world size ({self.world_size})"
                tp_size = config["builder_config"]["tensor_parallel"]
                # TP only for benchmarking
                assert (
                    tp_size == self.world_size
                ), f"Engine tensor parallel size ({tp_size}) should be equal to world size ({self.world_size})"
                assert (
                    config["plugin_config"]["remove_input_padding"] == False
                ), "remove_input_padding should be False for enc-dec benchmarks"
                num_heads = config["builder_config"]["num_heads"]
                assert (num_heads % tp_size) == 0
                # Get model config
                num_heads = num_heads // tp_size
                hidden_size = config["builder_config"]["hidden_size"] // tp_size
                num_kv_heads = config["builder_config"].get(
                    "num_kv_heads", config["builder_config"]["num_heads"])
                num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

                model_config = tensorrt_llm.runtime.ModelConfig(
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    hidden_size=hidden_size,
                    head_size=config["builder_config"]["head_size"],
                    vocab_size=config["builder_config"]["vocab_size"],
                    num_layers=config["builder_config"]["num_layers"],
                    gpt_attention_plugin=config["plugin_config"]
                    ["gpt_attention_plugin"],
                    remove_input_padding=config["plugin_config"]
                    ["remove_input_padding"],
                    cross_attention=config["builder_config"]["cross_attention"],
                    has_position_embedding=config["builder_config"]
                    ["has_position_embedding"],
                    has_token_type_embedding=config["builder_config"]
                    ["has_token_type_embedding"],
                    use_custom_all_reduce=config["plugin_config"].get(
                        "use_custom_all_reduce", False),
                    dtype=config_dtype,
                )
                # get builder config
                builder_config = dict()
                for key, value in config["builder_config"].items():
                    if key == "name":
                        engine_model_name = value
                    else:
                        builder_config[key] = value
                # get plugin config
                plugin_config = dict()
                for key, value in config["plugin_config"].items():
                    # Same effect as self.use_foo_plugin = config.json["foo_plugin"]
                    if "plugin" in key:
                        key = "use_" + key
                    plugin_config[key] = value
                return engine_model_name, model_config, builder_config, plugin_config

            (
                self.encoder_engine_model_name,
                self.encoder_model_config,
                self.encoder_builder_config,
                self.encoder_plugin_config,
            ) = read_config("encoder")
            (
                self.decoder_engine_model_name,
                self.decoder_model_config,
                self.decoder_builder_config,
                self.decoder_plugin_config,
            ) = read_config("decoder")

            self.encoder_engine_name = get_engine_name(
                self.encoder_engine_model_name,
                self.dtype,
                self.world_size,
                self.runtime_rank,
            )
            self.decoder_engine_name = get_engine_name(
                self.decoder_engine_model_name,
                self.dtype,
                self.world_size,
                self.runtime_rank,
            )
            self.encoder_runtime_mapping = tensorrt_llm.Mapping(
                world_size=self.world_size,
                rank=self.runtime_rank,
                tp_size=self.world_size,
                gpus_per_node=self.encoder_builder_config.get("gpus_per_node", 8),
            )
            self.decoder_runtime_mapping = tensorrt_llm.Mapping(
                world_size=self.world_size,
                rank=self.runtime_rank,
                tp_size=self.world_size,
                gpus_per_node=self.encoder_builder_config.get("gpus_per_node", 8),
            )

        if not args.serial_build:
            torch.cuda.set_device(self.runtime_rank %
                                  self.encoder_runtime_mapping.gpus_per_node)
        self.device = torch.cuda.current_device()

        if self.engine_dir is not None:
            # Deserialize engine from engine directory
            self.encoder_serialize_path = os.path.join(self.engine_dir,
                                                       "encoder",
                                                       self.encoder_engine_name)
            with open(self.encoder_serialize_path, "rb") as f:
                encoder_engine_buffer = f.read()
            self.decoder_serialize_path = os.path.join(self.engine_dir,
                                                       "decoder",
                                                       self.decoder_engine_name)
            with open(self.decoder_serialize_path, "rb") as f:
                decoder_engine_buffer = f.read()
        else:
            # TODO: Build engine
            assert False, "Engine directory is currently required for enc-dec benchmarks"
            encoder_engine_buffer = None
            decoder_engine_buffer = None

        assert encoder_engine_buffer is not None
        assert decoder_engine_buffer is not None

        # session setup
        self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
            encoder_engine_buffer)
        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            self.decoder_model_config,
            decoder_engine_buffer,
            self.decoder_runtime_mapping,
        )

    def get_config(self):
        max_batch_size = self.encoder_builder_config["max_batch_size"]
        for inlen, outlen in self.in_out_lens:
            if (inlen > self.encoder_builder_config["max_encoder_input_len"]
                    or outlen > self.encoder_builder_config["max_output_len"]):
                print(
                    f"[WARNING] check inlen({inlen}) <= max_inlen({self.max_input_len}) and "
                    f"outlen({outlen}) <= max_outlen({self.max_output_len}) failed, skipping."
                )
                continue
            for batch_size in self.batch_sizes:
                if batch_size > max_batch_size:
                    print(
                        f"[WARNING] check batch_size({batch_size}) "
                        f"<= max_batch_size({max_batch_size}) failed, skipping."
                    )
                    continue
                yield (batch_size, inlen, outlen)

    def prepare_inputs(self, config):
        batch_size, encoder_input_len = config[0], config[1]
        encoder_input_ids = (torch.randint(
            100, (batch_size, encoder_input_len)).int().cuda())
        # For now, just hardcode the decoder_start_token_id to 0 for t5 models.
        decoder_start_token_id = 0
        decoder_input_ids = torch.IntTensor([[decoder_start_token_id]
                                             ]).to(self.device)
        decoder_input_ids = decoder_input_ids.repeat(
            (encoder_input_ids.shape[0], 1))
        # in padding mode --> keep input, just calculate actual length and max length
        # Note: 1st token should always count, even if it is pad_token_id (0). e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
        encoder_input_lengths = ((1 + (encoder_input_ids[:, 1:] != 0).sum(
            dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
                dtype=torch.int32, device=self.device))
        decoder_input_lengths = ((1 + (decoder_input_ids[:, 1:] != 0).sum(
            dim=1).type(torch.IntTensor).to(self.device)).clone().detach().to(
                dtype=torch.int32, device=self.device))

        stream = torch.cuda.current_stream().cuda_stream
        return (
            encoder_input_ids,
            encoder_input_lengths,
            decoder_input_ids,
            decoder_input_lengths,
            stream,
        )

    def run(self, inputs, config, benchmark_profiler=None):
        output_len = config[2]
        (
            encoder_input_ids,
            encoder_input_lengths,
            decoder_input_ids,
            decoder_input_lengths,
            stream,
        ) = inputs

        hidden_size = (self.encoder_model_config.hidden_size *
                       self.encoder_runtime_mapping.tp_size)
        hidden_states_shape = (
            encoder_input_ids.shape[0],
            encoder_input_ids.shape[1],
            hidden_size,
        )
        hidden_states_dtype = lambda name: trt_dtype_to_torch(
            self.encoder_session.engine.get_tensor_dtype(name))

        # input tensors
        inputs = {}
        inputs["input_ids"] = encoder_input_ids.contiguous()
        inputs["input_lengths"] = encoder_input_lengths
        inputs["max_input_length"] = torch.empty(
            (self.encoder_builder_config["max_encoder_input_len"], ),
            dtype=hidden_states_dtype("max_input_length"),
            device=self.device,
        ).contiguous()

        # output tensors
        outputs = {}
        outputs["encoder_output"] = torch.empty(
            hidden_states_shape,
            dtype=hidden_states_dtype("encoder_output"),
            device=self.device,
        ).contiguous()

        # run encoder
        self.encoder_session.set_shapes(inputs)
        ok = self.encoder_session.run(inputs, outputs, stream)
        assert ok, "Runtime execution failed"
        torch.cuda.synchronize()

        # run decoder
        sampling_config = tensorrt_llm.runtime.SamplingConfig(
            end_id=1, pad_id=0, num_beams=self.num_beams, min_length=output_len)

        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            torch.max(decoder_input_lengths).item(),
            output_len,
            beam_width=self.num_beams,
            max_kv_cache_length=None,
            encoder_max_input_length=torch.max(encoder_input_lengths).item(),
        )
        torch.cuda.synchronize()

        self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=outputs["encoder_output"],
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

    def report(self,
               config,
               latency,
               percentile95,
               percentile99,
               peak_gpu_used,
               csv,
               benchmark_profiler=None):
        # Note: Theoretically, the encoder and decoder can have different configs.
        # But for current implementation, we assume they are the same. In the future,
        # we can have a special structure of report_dict for enc-dec models.
        report_dict = super().get_report_dict()
        batch_size, encoder_input_len, output_len = config[0], config[
            1], config[2]
        tokens_per_sec = round(batch_size * output_len / (latency / 1000), 2)
        report_dict["num_heads"] = self.encoder_model_config.num_heads
        report_dict["num_kv_heads"] = self.encoder_model_config.num_kv_heads
        report_dict["num_layers"] = self.encoder_model_config.num_layers
        report_dict["hidden_size"] = self.encoder_model_config.hidden_size
        report_dict["vocab_size"] = self.encoder_model_config.vocab_size
        report_dict["batch_size"] = batch_size
        report_dict["input_length"] = encoder_input_len
        report_dict["output_length"] = output_len
        report_dict["latency(ms)"] = latency
        report_dict["build_time(s)"] = self.build_time
        report_dict["tokens_per_sec"] = tokens_per_sec
        report_dict["percentile95(ms)"] = percentile95
        report_dict["percentile99(ms)"] = percentile99
        report_dict["gpu_peak_mem(gb)"] = peak_gpu_used
        if self.runtime_rank == 0:
            if csv:
                line = ",".join([str(v) for v in report_dict.values()])
                print(line)
                with open(self.get_csv_filename(), "a") as file:
                    file.write(line + "\n")
            else:
                kv_pairs = [f"{k} {v}" for k, v in report_dict.items()]
                line = "[BENCHMARK] " + " ".join(kv_pairs)
                print(line)
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from math import ceil

import torch
@@ -27,8 +26,9 @@ from build import build_gpt, get_quant_mode # isort:skip

class GPTBenchmark(BaseBenchmark):

    def __init__(self, args, batch_sizes, in_out_lens):
        super().__init__(args.engine_dir, args.model, args.dtype)
    def __init__(self, args, batch_sizes, in_out_lens, rank, world_size):
        super().__init__(args.engine_dir, args.model, args.dtype, rank,
                         world_size, args.serial_build)
        self.batch_sizes = batch_sizes
        self.in_out_lens = in_out_lens
        self.num_beams = args.num_beams
@@ -49,6 +49,13 @@ class GPTBenchmark(BaseBenchmark):
            setattr(self, key, value)
        if args.force_num_layer_1:
            self.num_layers = 1
        if args.max_batch_size is not None:
            self.max_batch_size = args.max_batch_size
        if args.max_input_len is not None:
            self.max_input_len = args.max_input_len
        if args.max_output_len is not None:
            self.max_output_len = args.max_output_len

        self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
        self.enable_fp8 = self.quant_mode.has_fp8_qdq()
        self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
@@ -62,11 +69,12 @@ class GPTBenchmark(BaseBenchmark):
        elif args.mode == 'ootb-except-mha':
            self.use_gpt_attention_plugin = True

        start = time.time()
        engine_buffer = build_gpt(args)
        self.build_time = round(time.time() - start, 2)
        engine_buffer, build_time = build_gpt(args)
        self.build_time = build_time

        assert engine_buffer is not None
        if args.build_only:
            return

        if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
            self.num_kv_heads = self.num_heads
@@ -194,7 +202,8 @@ class GPTBenchmark(BaseBenchmark):
                'generation_step_count')
            token_per_step = batch_size * self.num_beams
            total_tokens = generation_step_count * token_per_step
            report_dict["generation_time(ms)"] = generation_time_ms / iter_count
            report_dict["generation_time(ms)"] = round(
                generation_time_ms / iter_count, 3)
            report_dict["total_generated_tokens"] = total_tokens / iter_count
            tokens_per_second = round(
                total_tokens * 1000.0 / generation_time_ms, 3)

@@ -17,7 +17,6 @@
#pragma once

#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"

@@ -57,6 +56,9 @@ auto constexpr kReturnLogProbsTensorName = "return_log_probs";
auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";

// Obsolete names for backward compatibility
auto constexpr kInputLengthsTensorName = "input_lengths";

// Output tensors
auto constexpr kOutputIdsTensorName = "output_ids";
auto constexpr kSequenceLengthTensorName = "sequence_length";
@@ -131,6 +133,8 @@ public:
        inference_request::kReturnLogProbsTensorName,
        inference_request::kPromptEmbeddingTableName,
        inference_request::kPromptVocabSizeName,
        // obsolete names for backward compatibility
        inference_request::kInputLengthsTensorName,
    };

#define TENSOR_GETTER_SETTER(funcName, tensorName) \
@@ -191,11 +195,8 @@ public:
protected:
    static void validateTensorName(std::string const& tensorName)
    {
        // TODO (martinma): Throw an exception if the tensor name is not valid.
        if (std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) == kTensorNames.end())
        {
            TLLM_LOG_WARNING("Invalid tensor name in InferenceRequest: %s", tensorName.c_str());
        }
        TLLM_CHECK_WITH_INFO(std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) != kTensorNames.end(),
            "Invalid tensor name: %s", tensorName.c_str());
    }

    uint64_t mRequestId;

@@ -30,17 +30,19 @@ public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
        std::optional<SizeType> maxKvCacheLength = std::nullopt,
        std::optional<float> freeGpuMemoryFraction = std::nullopt)
        std::optional<SizeType> maxAttentionWindow = std::nullopt,
        std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false)
        : maxTokens{maxTokens}
        , maxKvCacheLength{maxKvCacheLength}
        , maxAttentionWindow{maxAttentionWindow}
        , freeGpuMemoryFraction{freeGpuMemoryFraction}
        , enableBlockReuse(enableBlockReuse)
    {
    }

    std::optional<SizeType> maxTokens;
    std::optional<SizeType> maxKvCacheLength;
    std::optional<SizeType> maxAttentionWindow;
    std::optional<float> freeGpuMemoryFraction;
    bool enableBlockReuse;

    static constexpr auto kDefaultGpuMemFraction = 0.85f;
};

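For readers skimming this change, here is a minimal sketch of how the extended constructor can be called. It relies only on the signature shown above; the namespace in the using-declaration and all concrete values are assumptions for illustration, not recommended settings (the `gptSessionBenchmark.cpp` hunk earlier assigns the same fields directly on `sessionConfig.kvCacheConfig` instead).

```cpp
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"

#include <optional>

// Namespace assumed; the hunk does not show the enclosing namespace of KvCacheConfig.
using tensorrt_llm::batch_manager::kv_cache_manager::KvCacheConfig;

// Illustrative only: cap the per-sequence attention window and turn on
// kv-cache block reuse, while leaving maxTokens unset.
KvCacheConfig makeExampleKvCacheConfig()
{
    return KvCacheConfig(
        /*maxTokens=*/std::nullopt,
        /*maxAttentionWindow=*/4096,    // placeholder value
        /*freeGpuMemoryFraction=*/0.9f, // placeholder value
        /*enableBlockReuse=*/true);
}
```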
@@ -17,8 +17,9 @@
#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
@@ -26,14 +27,52 @@

#include <NvInferRuntime.h>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>

namespace std
{

// Implement std::hash function object for vector<TokenIdType>.
// This allows us to use unordered_map with vector<TokenIdType> as key.
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933

template <>
struct hash<vector<int32_t>>
{
    size_t operator()(vector<int32_t> const& vec) const noexcept
    {
        size_t seed = vec.size();
        for (auto x : vec)
        {
            uint32_t y = static_cast<uint32_t>(x);
            y = ((y >> 16) ^ y) * 0x45d9f3b;
            y = ((y >> 16) ^ y) * 0x45d9f3b;
            y = (y >> 16) ^ y;
            seed ^= y + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }
        return seed;
    }
};

} // namespace std

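As a quick illustration of what this specialization buys (and what the `NextBlockMap` alias below relies on), the following standalone snippet keys an `unordered_map` directly by a token vector; the token values and block ids are made up for the example.

```cpp
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Assumes the std::hash<std::vector<int32_t>> specialization above is visible,
// e.g. by including this header; without it, the map below would not compile.
int main()
{
    using VecTokens = std::vector<int32_t>;
    std::unordered_map<VecTokens, int> blockIdByTokens;

    VecTokens chunkA{10, 11, 12}; // arbitrary example token ids
    VecTokens chunkB{10, 11, 99};
    blockIdByTokens[chunkA] = 0;  // pretend block 0 caches tokens 10, 11, 12
    blockIdByTokens[chunkB] = 1;  // a different chunk maps to a different block

    std::cout << "lookup -> block " << blockIdByTokens.at({10, 11, 12}) << "\n"; // prints 0
    return 0;
}
```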
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

class KVCacheBlock;

using SizeType = tensorrt_llm::runtime::SizeType;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using BeamTokens = std::vector<VecTokens>;
using BlockPtr = std::shared_ptr<KVCacheBlock>;
using FreeBlocksQueue = std::list<BlockPtr>;
using NextBlockMap = std::unordered_map<VecTokens, BlockPtr>;

struct KvCacheStats
{
@@ -49,8 +88,6 @@ struct KvCacheStats
class KVCacheBlock
{
public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit KVCacheBlock(SizeType blockIdx);

    void startScheduling();
@@ -67,6 +104,29 @@ public:

    [[nodiscard]] bool hasSchedulingRefs() const;

    void setTokens(VecTokens& tokens, bool isFull);

    [[nodiscard]] VecTokens const& getTokens() const;

    void setFreeBlockIterator(FreeBlocksQueue::iterator freeBlockIterator);

    void resetFreeBlockIterator();

    [[nodiscard]] std::optional<FreeBlocksQueue::iterator> const& getFreeBlockIterator() const;

    void setPrevBlock(BlockPtr prevBlock);

    void addNextBlock(VecTokens const& tokens, BlockPtr block);

    void removeNextBlock(VecTokens const& tokens);

    [[nodiscard]] BlockPtr findMatchingBlock(VecTokens const& tokens) const;

    //! \brief Free block from previous block if present.
    void freeLeafBlock();

    [[nodiscard]] bool isFull() const;

private:
    // Linear index of block in pool
    SizeType mBlockIdx;
@@ -76,6 +136,21 @@ private:

    // Number of references to the block
    SizeType mSchedulingRefCount;

    // Key of this block in mNextBlocks map in block pointed to by mPrevBlock
    VecTokens mTokens;

    // Previous block in sequence
    BlockPtr mPrevBlock;

    // Next block(s) in sequence(s)
    NextBlockMap mNextBlocks;

    // Iterator pointing to this block in mFreeBlocks.
    std::optional<FreeBlocksQueue::iterator> mFreeBlockIterator;

    // Flag indicating if block is full
    bool mIsFull;
};

class GenerationRequest
@@ -84,32 +159,22 @@ public:
    using SizeType = tensorrt_llm::runtime::SizeType;
    using SharedPtr = std::shared_ptr<GenerationRequest>;

    GenerationRequest(SizeType batchSlotIdx, SizeType numTokens, SizeType beamWidth)
        : mBatchSlotIdx(batchSlotIdx)
    explicit GenerationRequest(SizeType seqSlotIdx, SizeType numTokens, SizeType beamWidth)
        : mSeqSlotIdx(seqSlotIdx)
        , mNumTokens(numTokens)
        , mBeamWidth(beamWidth)
        , mCacheBlockIds(beamWidth)
    {
    }

    void setBatchSlotIdx(SizeType batchSlotIdx)
    {
        mBatchSlotIdx = batchSlotIdx;
    }

    void setNumTokens(SizeType numTokens)
    {
        mNumTokens = numTokens;
    }

    void addToken()
    {
        mNumTokens++;
    }

    [[nodiscard]] SizeType getBatchSlotIdx() const
    [[nodiscard]] SizeType getSequenceSlotIdx() const
    {
        return mBatchSlotIdx;
        return mSeqSlotIdx;
    }

    [[nodiscard]] SizeType getNumTokens() const
@@ -140,15 +205,29 @@ public:
        }
    }

    void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
    {
        mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
    }

    [[nodiscard]] std::vector<int> const& getNumPrepopulatedTokens() const
    {
        return mNumPrepopulatedTokens;
    }

private:
    // Index of sequence in the batch
    SizeType mBatchSlotIdx;
    // Slot id of the sequence
    SizeType mSeqSlotIdx;
    // Current number of generated tokens
    SizeType mNumTokens;
    // Number of beams
    SizeType mBeamWidth;
    // List of blocks allocated for each beam of the sequence
    std::vector<std::vector<SizeType>> mCacheBlockIds;
    // Number of tokens already in kv cache before context phase.
    // A value > 0 indicates cached kv cache blocks were reused.
    // One value per beam.
    std::vector<int> mNumPrepopulatedTokens;
};

// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
@@ -161,23 +240,34 @@ private:
// Block shape is [2, num_heads, tokens_per_block, head_size].
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// BlockManager maintains a vector of lists of batchSlotIdx to allocated blocks
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class BlockManager
{
public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit BlockManager(SizeType blocksInPool);
    explicit BlockManager(SizeType blocksInPool, SizeType tokensPerBlock);

    ~BlockManager();

    void startScheduling();

    //! \brief Assign blocks for new sequence. Try to reuse blocks.
    void addSequence(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest);

    //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
    void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);

    //! \brief Allocate new block for each beam of the sequence.
    //! \details Might free cached blocks if no free blocks are available.
    void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);

    void freeAllBlocks(GenerationRequest& sequence);
    //! \brief Release blocks of the sequence. Store blocks for reuse if llmRequest is provided.
    void releaseBlocks(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);

    // Simulate freeing all blocks for that sequence to check impact on number of free blocks
    void schedulingFreeAllBlocks(GenerationRequest& sequence);
    //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
    void schedulingReleaseBlocks(GenerationRequest& sequence);

    [[nodiscard]] SizeType getNumFreeBlocks() const
    {
@@ -186,7 +276,7 @@ public:

    [[nodiscard]] SizeType getNumAllocatedBlocks() const
    {
        return mAllocatedBlocks.size();
        return getMaxNumBlocks() - getNumFreeBlocks();
    }

    [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
@@ -199,13 +289,58 @@ public:
        return mSchedulingNumFreeBlocks >= numRequired;
    }

    [[nodiscard]] SizeType getMaxNumBlocks() const
    {
        return static_cast<SizeType>(mAllBlocksByIdx.size());
    }

    [[nodiscard]] SizeType getTokensPerBlock() const
    {
        return mTokensPerBlock;
    }

private:
    //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
    void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);

    //! \brief Store blocks in cached blocks.
    //! \param blockedTokens Tokens of each block.
    //! \param blockIds Id of each block.
    void storeBlocks(std::list<VecTokens> blockedTokens, std::vector<SizeType> const& blockIds);

    //! \brief Try to load blocks from cache. Allocate new blocks if necessary.
    //! \param blockedTokens Tokens of each block.
    //! \param sequence Sequence to which blocks are assigned.
    //! \param beamIdx Beam of sequence to which blocks are assigned.
    //! \param seqSlotIdx Batch slot of sequence to which blocks are assigned.
    //! \return Number of matched tokens from loaded blocks.
    SizeType loadOrAllocateBlocks(
        std::list<VecTokens> blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);

    //! \brief Find block least likely to be reused, free it if necessary and return.
    [[nodiscard]] BlockPtr getFreeBlock();

    //! \brief Claim block if it is in free blocks list.
    void claimBlock(KVCacheBlock& block);

    //! \brief Free block from previous block and claim it from free blocks list.
    void claimLeafBlock(KVCacheBlock& block);

private:
    // List of free blocks
    std::list<KVCacheBlock> mFreeBlocks;
    FreeBlocksQueue mFreeBlocks;
    // List of allocated blocks for each sequences
    std::vector<std::vector<KVCacheBlock>> mAllocatedBlocks;
    std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
    // Used to keep track of number of free blocks during scheduling
    SizeType mSchedulingNumFreeBlocks;
    // Number of tokens per one block
    SizeType mTokensPerBlock;
    // List of all blocks by idx
    std::vector<BlockPtr> mAllBlocksByIdx;
    // Dummy block acting as root for BlockToken searches
    BlockPtr mCachedBlocksRoot;
    // Statistics for block allocations/reuse
    std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks;
};

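To summarize the reuse flow these declarations describe (split a sequence into `tokensPerBlock`-sized chunks, reuse a stored block when the chunk has been seen before, otherwise allocate and store a new one), here is a deliberately simplified, self-contained sketch. It is not the real `BlockManager`: it uses a flat map instead of the per-block `NextBlockMap` chaining rooted at `mCachedBlocksRoot`, and it ignores beams, eviction, and reference counts.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Simplified illustration of KV-cache block reuse, NOT the real BlockManager:
// chunks of tokensPerBlock tokens act as lookup keys for previously stored blocks.
using VecTokens = std::vector<int32_t>;

struct ToyBlockCache
{
    int tokensPerBlock = 4;
    int nextBlockId = 0;
    std::map<VecTokens, int> storedBlocks; // chunk of tokens -> block id

    // Returns block ids for a sequence, reusing stored chunks when they match.
    std::vector<int> loadOrAllocate(VecTokens const& tokens, int& reusedTokens)
    {
        std::vector<int> blockIds;
        reusedTokens = 0;
        for (size_t i = 0; i + tokensPerBlock <= tokens.size(); i += tokensPerBlock)
        {
            VecTokens chunk(tokens.begin() + i, tokens.begin() + i + tokensPerBlock);
            auto it = storedBlocks.find(chunk);
            if (it != storedBlocks.end())
            {
                blockIds.push_back(it->second); // reuse: tokens already in kv cache
                reusedTokens += tokensPerBlock;
            }
            else
            {
                int id = nextBlockId++;
                storedBlocks[chunk] = id; // store for future reuse
                blockIds.push_back(id);
            }
        }
        return blockIds;
    }
};

int main()
{
    ToyBlockCache cache;
    int reused = 0;
    cache.loadOrAllocate({1, 2, 3, 4, 5, 6, 7, 8}, reused);    // allocates blocks 0, 1
    cache.loadOrAllocate({1, 2, 3, 4, 9, 10, 11, 12}, reused); // reuses block 0, allocates 2
    std::cout << "prepopulated tokens on second request: " << reused << "\n"; // prints 4
    return 0;
}
```

The `reusedTokens` counter plays the role of the per-beam `mNumPrepopulatedTokens` value introduced in `GenerationRequest` above: tokens whose kv cache entries already exist when a new context phase starts.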
class KVCacheManager
|
||||
@ -216,19 +351,20 @@ public:
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
|
||||
KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
|
||||
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxBlocksPerSeq, SizeType maxKvCacheLength, nvinfer1::DataType dtype, CudaStreamPtr stream);
|
||||
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth,
|
||||
SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, nvinfer1::DataType dtype, CudaStreamPtr stream,
|
||||
bool enableBlockReuse = false);
|
||||
|
||||
void startScheduling();
|
||||
|
||||
[[nodiscard]] SizeType getTokensPerBlock() const
|
||||
{
|
||||
return mTokensPerBlock;
|
||||
return mBlockManager.getTokensPerBlock();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxNumBlocks() const
|
||||
{
|
||||
return mMaxNumBlocks;
|
||||
return mBlockManager.getMaxNumBlocks();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getUsedNumBlocks() const
|
||||
@ -267,32 +403,33 @@ public:
|
||||
/// iterations
|
||||
/// @param req The request for which we need to calculate the number of needed KV cache blocks
|
||||
/// @return The number of blocks
|
||||
SizeType getNeededBlocksOneStep(const LlmRequest& req, bool twoStepsLookAhead) const;
|
||||
SizeType getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
|
||||
|
||||
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
|
||||
/// maxNewTokens)
|
||||
/// @param req The request for which we need to calculate the number of needed KV cache blocks
|
||||
/// @return The number of blocks
|
||||
SizeType getNeededBlocksToCompletion(const LlmRequest& req) const;
|
||||
SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
|
||||
|
||||
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
|
||||
{
|
||||
return mPools;
|
||||
}
|
||||
|
||||
void addToken(SizeType batchSlotIdx);
|
||||
void addToken(SizeType seqSlotIdx);
|
||||
|
||||
void addSequence(SizeType batchSlotIdx, SizeType inputLength, SizeType beamWidth);
|
||||
void addSequence(SizeType seqSlotIdx, SizeType inputLength, SizeType beamWidth,
|
||||
std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
|
||||
|
||||
void removeSequence(SizeType batchSlotIdx);
|
||||
void removeSequence(SizeType seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
|
||||
|
||||
void schedulingRemoveSequence(SizeType batchSlotIdx);
|
||||
void schedulingRemoveSequence(SizeType seqSlotIdx);
|
||||
|
||||
void getBlockPointersOfBatch(
|
||||
runtime::ITensor& dstPointers, SizeType firstBatchSlotIdx, SizeType batchSize, SizeType beamWidth) const;
|
||||
|
||||
void copyBlockPointers(
|
||||
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType batchSlotIdx, SizeType beamWidth) const;
|
||||
runtime::ITensor& dstPointers, SizeType dstSlotOffset, SizeType seqSlotIdx, SizeType beamWidth) const;
|
||||
|
||||
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] static SizeType constexpr calculatePageSize(tensorrt_llm::runtime::GptModelConfig const& modelConfig)
|
||||
@ -313,26 +450,22 @@ public:
|
||||
runtime::BufferManager const& bufferManager);
|
||||
|
||||
private:
|
||||
void resetBlockPointers(SizeType batchSlotIdx, SizeType beamWidth);
|
||||
|
||||
void cacheNewBlockPointer(const GenerationRequest& seq, SizeType batchSlotIdx);
|
||||
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
|
||||
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
|
||||
private:
|
||||
// Number of elements per block
|
||||
SizeType mBlockSize;
|
||||
// Number of tokens per block
|
||||
SizeType mTokensPerBlock;
|
||||
// Total maximum number of blocks
|
||||
SizeType mMaxNumBlocks;
|
||||
// Maximum size of batch
|
||||
SizeType mMaxBatchSize;
|
||||
// Maximum number of sequences
|
||||
SizeType mMaxNumSequences;
|
||||
// Maximum beam width
|
||||
SizeType mMaxBeamWidth;
|
||||
// Maximum number of blocks per sequence
|
||||
SizeType mMaxBlocksPerSeq;
|
||||
// Maximum kv cache length per sequence
|
||||
// Enable cyclic kv cache when the sequence length exceeds it
|
||||
SizeType mMaxKvCacheLength;
|
||||
SizeType mMaxAttentionWindow;
|
||||
// Pools
|
||||
std::vector<runtime::ITensor::SharedPtr> mPools;
|
||||
// Block manager
|
||||
@ -341,7 +474,9 @@ private:
|
||||
std::vector<SequencesPtr> mSequences;
|
||||
// buffer for block pointers for all managed sequences
|
||||
runtime::ITensor::SharedPtr mSequenceBlockPointers;
|
||||
|
||||
runtime::BufferManager mManager;
|
||||
// Buffer manager
|
||||
runtime::BufferManager mBufferManager;
|
||||
// Whether to cache KV pages for reuse
|
||||
bool mEnableBlockReuse;
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
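For reference, the updated constructor above now sizes the manager by maxNumSequences, takes the attention-window length that triggers the cyclic KV cache, and accepts an enableBlockReuse flag. A hedged construction sketch; the include paths and all numeric values are illustrative assumptions:

#include "tensorrt_llm/batch_manager/kvCacheManager.h" // assumed header location
#include "tensorrt_llm/runtime/cudaStream.h"

#include <memory>

using tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager;

std::unique_ptr<KVCacheManager> buildKvCacheManager()
{
    auto stream = std::make_shared<tensorrt_llm::runtime::CudaStream>();

    // Placeholder model/runtime dimensions.
    int const numLayers = 32, numHeads = 32, numKvHeads = 8, hiddenSize = 4096;
    int const tokensPerBlock = 64, maxNumBlocks = 2048;
    int const maxNumSequences = 16, maxBeamWidth = 1, maxBlocksPerSeq = 128;
    int const maxAttentionWindow = 4096; // KV cache becomes cyclic once a sequence exceeds this length

    return std::make_unique<KVCacheManager>(numLayers, numHeads, numKvHeads, hiddenSize, tokensPerBlock,
        maxNumBlocks, maxNumSequences, maxBeamWidth, maxBlocksPerSeq, maxAttentionWindow,
        nvinfer1::DataType::kHALF, stream, /*enableBlockReuse=*/true);
}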
|
||||
|
||||
@ -48,11 +48,10 @@ public:
|
||||
using BeamTokens = std::vector<VecTokens>;
|
||||
using TensorPtr = TTensor;
|
||||
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens,
|
||||
std::shared_ptr<std::vector<TokenIdType>> inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming,
|
||||
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
|
||||
std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
|
||||
@ -65,7 +64,7 @@ public:
|
||||
, mIsStreaming(isStreaming)
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mBatchSlot(-1)
|
||||
, mSeqSlot(-1)
|
||||
, mOrigPromptLen(inputTokens->size())
|
||||
, mEmbeddingBias(embeddingBias)
|
||||
, mBadWordsList(badWordsList)
|
||||
@ -132,14 +131,21 @@ public:
|
||||
/// @brief Get the tokens at a given beam index
|
||||
/// @param beam The beam index
|
||||
/// @return A vector of tokens for this beam index, includes the prompt
|
||||
std::vector<TokenIdType> const& getTokens(SizeType beam) const
|
||||
VecTokens const& getTokens(SizeType beam) const
|
||||
{
|
||||
return mTokens.at(beam);
|
||||
}
|
||||
|
||||
/// @brief Get all tokens (input+output) for all beams
|
||||
/// @return A vector of vector of tokens.
|
||||
BeamTokens const& getTokens() const
|
||||
{
|
||||
return mTokens;
|
||||
}
|
||||
|
||||
/// @brief Get the draft tokens
|
||||
/// @return shared_ptr to vector of draft tokens
|
||||
std::shared_ptr<std::vector<TokenIdType>> const& getDraftTokens() const
|
||||
std::shared_ptr<VecTokens> const& getDraftTokens() const
|
||||
{
|
||||
return mDraftTokens;
|
||||
}
|
||||
@ -176,7 +182,7 @@ public:
|
||||
/// @brief Add new generated tokens to the vector of tokens
|
||||
/// @param beamTokens A vector containing the tokens to add for each beam index
|
||||
/// beamTokens is expected to be of size beamWidth
|
||||
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
|
||||
void addNewTokens(VecTokens const& beamTokens)
|
||||
{
|
||||
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
|
||||
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
|
||||
@ -236,7 +242,7 @@ public:
|
||||
mPromptLen = newPromptLen;
|
||||
}
|
||||
mState = REQUEST_STATE_CONTEXT_INIT;
|
||||
mBatchSlot = -1;
|
||||
mSeqSlot = -1;
|
||||
}
|
||||
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
@ -335,7 +341,7 @@ public:
|
||||
bool mIsStreaming;
|
||||
std::optional<SizeType> mEndId;
|
||||
std::optional<SizeType> mPadId;
|
||||
SizeType mBatchSlot;
|
||||
SizeType mSeqSlot;
|
||||
|
||||
protected:
|
||||
SizeType mOrigPromptLen;
|
||||
@ -369,7 +375,7 @@ public:
|
||||
using BeamTokens = Base::BeamTokens;
|
||||
using VecTokens = Base::VecTokens;
|
||||
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> inputTokens,
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
|
||||
@ -34,11 +34,13 @@ public:
|
||||
|
||||
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
|
||||
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
|
||||
bool useContextFMHAForGeneration = false)
|
||||
bool useContextFMHAForGeneration = false,
|
||||
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt)
|
||||
: kvCacheConfig{kvCacheConfig}
|
||||
, maxNumSequences{maxNumSequences}
|
||||
, enableTrtOverlap{enableTrtOverlap}
|
||||
, useContextFMHAForGeneration(useContextFMHAForGeneration)
|
||||
, userSpecifiedDeviceIds(userSpecifiedDeviceIds)
|
||||
{
|
||||
}
|
||||
|
||||
@ -46,6 +48,7 @@ public:
|
||||
std::optional<SizeType> maxNumSequences;
|
||||
bool enableTrtOverlap;
|
||||
bool useContextFMHAForGeneration;
|
||||
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
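The added userSpecifiedDeviceIds member lets a caller pin the batch manager to a specific set of GPUs instead of the default rank-to-device mapping. A short sketch of filling the struct; the include paths and the SizeType alias location are assumptions, and the device list is an example:

#include "tensorrt_llm/batch_manager/kvCacheConfig.h"             // assumed header locations
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/runtime/common.h"

#include <optional>
#include <vector>

using tensorrt_llm::runtime::SizeType;
using namespace tensorrt_llm::batch_manager;

TrtGptModelOptionalParams makeOptionalParams()
{
    // Run on GPUs 2 and 3 rather than the default devices 0..tp*pp-1.
    std::vector<SizeType> const deviceIds{2, 3};
    return TrtGptModelOptionalParams(KvCacheConfig{}, /*maxNumSequences=*/std::nullopt,
        /*enableTrtOverlap=*/true, /*useContextFMHAForGeneration=*/false, deviceIds);
}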
|
||||
|
||||
@ -29,10 +29,11 @@ class DecodingInput
|
||||
public:
|
||||
using TensorPtr = std::shared_ptr<ITensor const>;
|
||||
|
||||
DecodingInput(SizeType maxLength, SizeType maxKvCacheLength, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
|
||||
DecodingInput(
|
||||
SizeType maxLength, SizeType maxAttentionWindow, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
|
||||
: step{maxLength}
|
||||
, maxLength{maxLength}
|
||||
, maxKvCacheLength{maxKvCacheLength}
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, batchSize{batchSize}
|
||||
, logits{std::move(logits)}
|
||||
, endIds{std::move(endIds)}
|
||||
@ -44,7 +45,7 @@ public:
|
||||
// mandatory parameters
|
||||
SizeType step;
|
||||
SizeType maxLength;
|
||||
SizeType maxKvCacheLength;
|
||||
SizeType maxAttentionWindow;
|
||||
SizeType batchSize;
|
||||
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
|
||||
TensorPtr endIds; // [batchSize * beamWidth], on gpu
|
||||
|
||||
@ -45,7 +45,7 @@ public:
|
||||
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
|
||||
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
|
||||
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
|
||||
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
@ -200,7 +200,7 @@ private:
|
||||
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxKvCacheLength{};
|
||||
SizeType mMaxAttentionWindow{};
|
||||
SizeType mActualBatchSize{};
|
||||
SizeType mMaxTokensPerStep{};
|
||||
};
|
||||
|
||||
@ -32,9 +32,10 @@ namespace tensorrt_llm::runtime
|
||||
class GptJsonConfig
|
||||
{
|
||||
public:
|
||||
GptJsonConfig(std::string name, std::string precision, SizeType tensorParallelism, SizeType pipelineParallelism,
|
||||
GptModelConfig const& modelConfig)
|
||||
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType tensorParallelism,
|
||||
SizeType pipelineParallelism, GptModelConfig const& modelConfig)
|
||||
: mName(std::move(name))
|
||||
, mVersion(std::move(version))
|
||||
, mPrecision(std::move(precision))
|
||||
, mTensorParallelism{tensorParallelism}
|
||||
, mPipelineParallelism{pipelineParallelism}
|
||||
@ -58,6 +59,11 @@ public:
|
||||
return mName;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string const& getVersion() const
|
||||
{
|
||||
return mVersion;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string const& getPrecision() const
|
||||
{
|
||||
return mPrecision;
|
||||
@ -87,6 +93,7 @@ public:
|
||||
|
||||
private:
|
||||
std::string const mName;
|
||||
std::string const mVersion;
|
||||
std::string const mPrecision;
|
||||
SizeType const mTensorParallelism;
|
||||
SizeType const mPipelineParallelism;
|
||||
|
||||
@ -46,6 +46,7 @@ public:
|
||||
, mTokensPerBlock{64}
|
||||
, mQuantMode{common::QuantMode::none()}
|
||||
, mMaxBatchSize(0)
|
||||
, mMaxBeamWidth(0)
|
||||
, mMaxInputLen(0)
|
||||
, mMaxOutputLen(0)
|
||||
, mMaxNumTokens(std::nullopt)
|
||||
@ -169,6 +170,16 @@ public:
|
||||
mMaxBatchSize = maxBatchSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getMaxBeamWidth() const noexcept
|
||||
{
|
||||
return mMaxBeamWidth;
|
||||
}
|
||||
|
||||
void constexpr setMaxBeamWidth(SizeType maxBeamWidth) noexcept
|
||||
{
|
||||
mMaxBeamWidth = maxBeamWidth;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getMaxInputLen() const noexcept
|
||||
{
|
||||
return mMaxInputLen;
|
||||
@ -259,6 +270,11 @@ public:
|
||||
mMaxDraftLen = maxDraftLen;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxDraftLen() const
|
||||
{
|
||||
return mMaxDraftLen;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getMaxTokensPerStep() const noexcept
|
||||
{
|
||||
return mMaxDraftLen + 1;
|
||||
@ -277,6 +293,7 @@ private:
|
||||
SizeType mTokensPerBlock;
|
||||
common::QuantMode mQuantMode;
|
||||
SizeType mMaxBatchSize;
|
||||
SizeType mMaxBeamWidth;
|
||||
SizeType mMaxInputLen;
|
||||
SizeType mMaxOutputLen;
|
||||
std::optional<SizeType> mMaxNumTokens;
|
||||
|
||||
@ -143,9 +143,9 @@ private:
|
||||
|
||||
void createContexts(SizeType numBatchesCtx, SizeType numBatchesGen, bool useCudaGraphs);
|
||||
void createBuffers(SizeType numMicroBatches);
|
||||
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
|
||||
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
|
||||
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
|
||||
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
|
||||
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, KvCacheConfig const& config);
|
||||
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
|
||||
|
||||
@ -261,7 +261,7 @@ private:
|
||||
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;
|
||||
|
||||
SizeType mDecoderMaxSequenceLength{};
|
||||
SizeType mDecoderMaxKvCacheLength{};
|
||||
SizeType mDecoderMaxAttentionWindow{};
|
||||
|
||||
LoggerPtr mLogger;
|
||||
std::shared_ptr<TllmRuntime> mRuntime;
|
||||
|
||||
@ -74,7 +74,7 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
|
||||
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
|
||||
= 0;
|
||||
|
||||
|
||||
@ -29,12 +29,13 @@ class WorldConfig
|
||||
public:
|
||||
static SizeType constexpr kDefaultGpusPerNode = 8;
|
||||
|
||||
constexpr explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
|
||||
SizeType gpusPerNode = kDefaultGpusPerNode)
|
||||
explicit WorldConfig(SizeType tensorParallelism = 1, SizeType pipelineParallelism = 1, SizeType rank = 0,
|
||||
SizeType gpusPerNode = kDefaultGpusPerNode, std::vector<SizeType> deviceIds = {})
|
||||
: mTensorParallelism{tensorParallelism}
|
||||
, mPipelineParallelism{pipelineParallelism}
|
||||
, mRank{rank}
|
||||
, mGpusPerNode{gpusPerNode}
|
||||
, mDeviceIds{deviceIds}
|
||||
{
|
||||
}
|
||||
|
||||
@ -73,8 +74,12 @@ public:
|
||||
return mGpusPerNode;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getDevice() const noexcept
|
||||
[[nodiscard]] SizeType getDevice() const noexcept
|
||||
{
|
||||
if (mDeviceIds.size())
|
||||
{
|
||||
return mDeviceIds[mRank % mGpusPerNode];
|
||||
}
|
||||
return mRank % mGpusPerNode;
|
||||
}
|
||||
|
||||
@ -110,17 +115,20 @@ public:
|
||||
|
||||
static WorldConfig mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode = kDefaultGpusPerNode,
|
||||
std::optional<SizeType> tensorParallelism = std::nullopt,
|
||||
std::optional<SizeType> pipelineParallelism = std::nullopt);
|
||||
std::optional<SizeType> pipelineParallelism = std::nullopt,
|
||||
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
|
||||
|
||||
static WorldConfig mpi(SizeType gpusPerNode = kDefaultGpusPerNode,
|
||||
std::optional<SizeType> tensorParallelism = std::nullopt,
|
||||
std::optional<SizeType> pipelineParallelism = std::nullopt);
|
||||
std::optional<SizeType> pipelineParallelism = std::nullopt,
|
||||
std::optional<std::vector<SizeType>> userSpecifiedDeviceIds = std::nullopt);
|
||||
|
||||
private:
|
||||
SizeType mTensorParallelism;
|
||||
SizeType mPipelineParallelism;
|
||||
SizeType mRank;
|
||||
SizeType mGpusPerNode;
|
||||
std::vector<SizeType> mDeviceIds;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
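With this change, getDevice() goes through the optional deviceIds list when one is supplied, so rank r runs on deviceIds[r % gpusPerNode] instead of device r % gpusPerNode. A small sketch (the values are illustrative):

#include "tensorrt_llm/runtime/worldConfig.h"

#include <cassert>

using tensorrt_llm::runtime::WorldConfig;

int main()
{
    // Default mapping: rank 1 of a 2-way tensor-parallel job lands on device 1.
    WorldConfig defaultCfg(/*tensorParallelism=*/2, /*pipelineParallelism=*/1, /*rank=*/1);
    assert(defaultCfg.getDevice() == 1);

    // User-specified mapping: the same rank is redirected to device 5.
    WorldConfig pinnedCfg(/*tensorParallelism=*/2, /*pipelineParallelism=*/1, /*rank=*/1,
        /*gpusPerNode=*/8, /*deviceIds=*/{4, 5});
    assert(pinnedCfg.getDevice() == 5);
    return 0;
}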
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ba982afff27c597c9f5f25bec4ed37debd883c7be2107b47776a014075899fbd
|
||||
size 1719266
|
||||
oid sha256:7d9f7d0f7dee2c48a424ff8873c2fd1298a27850f870657734641f2eb1190faf
|
||||
size 1791038
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04ec1f2f45dde1ef6b6b0f605e79715eebed38b19b4d833fcb668d2cb71f8a03
|
||||
size 1733118
|
||||
oid sha256:fa79a0d563fc01a0cb2fe94dcb626ff4e5b736284d9244313cbe7aa0261dd48e
|
||||
size 1806500
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
aab384dfc59de5df4c7ecf53e30d03e9 libtensorrt_llm_batch_manager_static.a
|
||||
e0074afa6959c896f1cbc7ab90872058 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
c7450aa071e91659a3e2855c0cca21021f96ada8 commit
|
||||
d9723ab671c9fc3889cc624a58def81a libtensorrt_llm_batch_manager_static.a
|
||||
4b6773c990e8a59f1c716d88505b84a2 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
9a136bb59c51bbae09221c1667e23529ed05c752 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:546c9e2b79cb3cf2623876902ef2d40c65925157d43850b2505eedf274e060a1
|
||||
size 1638840
|
||||
oid sha256:6a7b872fe6ee63a4342c3cd17b3557d74c72e537dbf0d4ddf132a2c40e000e57
|
||||
size 1709462
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:935a706ce0d107f8c226566a50946a0f0e35ce926c98b7a12b000b3d72e5f0b6
|
||||
size 1635602
|
||||
oid sha256:c83f7c0e4fc22b32df669ada2b99b88f0f7faac935a251fe7a20030e2b364cc8
|
||||
size 1705432
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
8e0c5b31d579f4118b84a34ffb00c15a libtensorrt_llm_batch_manager_static.a
|
||||
885ea7b9f594d7aa9cc9018527b95f6d libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
583141c3003a08acebc7054d024bee89 libtensorrt_llm_batch_manager_static.a
|
||||
03e5360f9b8074b8273500898581212f libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -102,12 +102,12 @@ struct Multihead_attention_params_base
|
||||
int batch_size = 0;
|
||||
// The beam width
|
||||
int beam_width = 0;
|
||||
// By default, max_kv_cache_length == cyclic_kv_cache_length
|
||||
// By default, max_attention_window_size == cyclic_attention_window_size
|
||||
// unless each layer has a different cyclic kv cache length.
|
||||
// Max cache capacity (used to allocate KV cache)
|
||||
int max_kv_cache_length = 0;
|
||||
int max_attention_window_size = 0;
|
||||
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
|
||||
int cyclic_kv_cache_length = 0;
|
||||
int cyclic_attention_window_size = 0;
|
||||
// The number of heads (H).
|
||||
int num_heads = 0;
|
||||
// Controls MHA/MQA/GQA
|
||||
|
||||
@ -19,15 +19,27 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
// clang-format off
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin[];
|
||||
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len;
|
||||
|
||||
|
||||
static const struct XQAKernelMetaInfo
|
||||
@ -42,10 +54,16 @@ static const struct XQAKernelMetaInfo
|
||||
unsigned int mCubinSize;
|
||||
const char* mFuncName;
|
||||
} sXqaKernelMetaInfo[] = {
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
|
||||
};
|
||||
|
||||
// clang-format on
|
||||
|
||||
(14 file diffs suppressed because they are too large)
@ -44,8 +44,8 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_AT
|
||||
using Tk = typename kernel_type_t<T>::Type;
|
||||
// The amount of shared memory needed to store the Q*K^T values in float.
|
||||
const int max_timesteps = DO_CROSS_ATTENTION
|
||||
? params.cyclic_kv_cache_length
|
||||
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_kv_cache_length);
|
||||
? params.cyclic_attention_window_size
|
||||
: min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_attention_window_size);
|
||||
const auto qk_elts = static_cast<std::size_t>(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign
|
||||
const auto qk_sz = qk_elts * 16;
|
||||
|
||||
@ -110,6 +110,9 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
|
||||
|
||||
params.seq_len_tile = std::min(params.seq_len_tile, max_seq_len_tile);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
params.seq_len_tile <= block_size, "The number of blocks per sequence may not exceed the thread block size.");
|
||||
|
||||
// We should consider the new timestep.
|
||||
params.timesteps_per_block = mmha::divUp(tlength + 1, params.seq_len_tile);
|
||||
|
||||
|
||||
@ -1273,9 +1273,9 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
// The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L.
|
||||
// Note that the maximum sequence length supported by the model might be greater than this.
|
||||
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
|
||||
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
|
||||
// By default, you can assume that they are the same.
|
||||
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_kv_cache_length);
|
||||
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
|
||||
// The current timestep (including paddings).
|
||||
// It is only used to calculate the smem stride.
|
||||
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
|
||||
@ -1796,11 +1796,12 @@ __global__ void masked_multihead_attention_kernel(
|
||||
: divUp(static_cast<unsigned>(kv_loop_length), K_PER_WARP) * K_PER_WARP;
|
||||
|
||||
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
||||
// Note max_kv_cache_length is maximum of cyclic_kv_cache_length among all layers.
|
||||
// Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
|
||||
// By default, you can assume that they are the same.
|
||||
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_kv_cache_length;
|
||||
// Beam indices are based on the max_kv_cache_length while each layer may have different cyclic_kv_cache_length
|
||||
// So we need to rebuild the beam_indices if max_kv_cache_length is not equal to cyclic_kv_cache_length.
|
||||
const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_attention_window_size;
|
||||
// Beam indices are based on the max_attention_window_size while each layer may have different
|
||||
// cyclic_attention_window_size. So we need to rebuild the beam_indices if max_attention_window_size is not equal to
|
||||
// cyclic_attention_window_size.
|
||||
const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
||||
|
||||
const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG
|
||||
@ -2596,39 +2597,38 @@ __global__ void masked_multihead_attention_kernel(
|
||||
T* out_oi_smem = reinterpret_cast<T*>(smem_);
|
||||
|
||||
const auto o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
|
||||
// The partial output region this thread takes care of
|
||||
const auto oo = o_idx.x;
|
||||
|
||||
// Init partial out for accumulation.
|
||||
V_vec_k zero_k;
|
||||
zero(zero_k);
|
||||
V_vec_k thread_accumulated_out = zero_k;
|
||||
|
||||
// The hidden dimensions computed by this particular thread. (refer to vi)
|
||||
const auto oi = o_idx.y;
|
||||
|
||||
// Within the bound.
|
||||
const bool within_bound = oo < gridDim.z;
|
||||
// The partial output region this thread takes care of
|
||||
const auto oo = o_idx.x;
|
||||
|
||||
// Load partial output
|
||||
int thread_partial_out_offset = oo * params.batch_size * num_heads * params.hidden_size_per_head;
|
||||
// Load partial max (different to thread_partial_max since the threadIdx rule changes here)
|
||||
float thread_partial_max_for_out = within_bound ? params.partial_max[bhi_seq_len_tile + oo] : final_max;
|
||||
|
||||
// Load the partial outputs.
|
||||
V_vec_k zero_k;
|
||||
zero(zero_k);
|
||||
V_vec_k thread_partial_out = within_bound
|
||||
? *reinterpret_cast<const V_vec_k*>(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi])
|
||||
: zero_k;
|
||||
|
||||
Tk factor_compute;
|
||||
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
|
||||
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
|
||||
|
||||
// Make sure we can start writing to shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// The reduction iteration should start with a number which is a power of 2
|
||||
const auto reduction_iteration = static_cast<int>(cuda::std::bit_ceil(gridDim.z));
|
||||
// Each thread may handle more than one partial output.
|
||||
for (int tile_idx = o_idx.x; tile_idx < gridDim.z; tile_idx += V_PER_ITER)
|
||||
{
|
||||
// Load partial output
|
||||
int thread_partial_out_offset = tile_idx * params.batch_size * num_heads * params.hidden_size_per_head;
|
||||
// Load partial max (different to thread_partial_max since the threadIdx rule changes here)
|
||||
float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + tile_idx];
|
||||
// Load the partial outputs.
|
||||
V_vec_k thread_partial_out
|
||||
= *reinterpret_cast<const V_vec_k*>(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi]);
|
||||
// Apply the correction factor.
|
||||
Tk factor_compute;
|
||||
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
|
||||
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
|
||||
thread_accumulated_out = add(thread_partial_out, thread_accumulated_out);
|
||||
}
|
||||
|
||||
// Run the final reduction amongst the different groups computing different partial outputs.
|
||||
#pragma unroll
|
||||
for (int active_groups = reduction_iteration; active_groups >= 2; active_groups /= 2)
|
||||
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2)
|
||||
{
|
||||
|
||||
// The midpoint in the number of active groups.
|
||||
@ -2637,15 +2637,15 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// The upper part of active threads store to shared memory.
|
||||
if (oo >= midpoint && oo < active_groups && (Dh == Dh_MAX || oi < Dh))
|
||||
{
|
||||
*reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_partial_out;
|
||||
*reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_accumulated_out;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The bottom warps update their values.
|
||||
if (oo < midpoint && (Dh == Dh_MAX || oi < Dh))
|
||||
{
|
||||
thread_partial_out
|
||||
= add(thread_partial_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
|
||||
thread_accumulated_out
|
||||
= add(thread_accumulated_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -2661,8 +2661,8 @@ __global__ void masked_multihead_attention_kernel(
|
||||
Tk inv_sum_compute;
|
||||
convert_from_float(&inv_sum_compute, inv_sum);
|
||||
|
||||
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_partial_out);
|
||||
*reinterpret_cast<V_vec_k*>(¶ms.out[bhi * Dh + oi]) = thread_partial_out;
|
||||
thread_accumulated_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_accumulated_out);
|
||||
*reinterpret_cast<V_vec_k*>(¶ms.out[bhi * Dh + oi]) = thread_accumulated_out;
|
||||
}
|
||||
|
||||
// Reset qk_current_smem and block_counter for the next timestep
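The rework above replaces the single per-thread partial result with a loop that accumulates every partial output the thread owns, rescaling each one by exp(partial_max - final_max) before the shared-memory tree reduction. The same numerics in a standalone CPU sketch (scalar rather than vectorized, and folding in the inverse-sum step for brevity; illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Combine per-tile partial attention results (out_t, max_t, sum_t) into the final output,
// mirroring the rescale-and-accumulate pattern used by the multi-block reduction.
float combinePartials(
    std::vector<float> const& partialOut, std::vector<float> const& partialMax, std::vector<float> const& partialSum)
{
    float const finalMax = *std::max_element(partialMax.begin(), partialMax.end());

    float accumulatedOut = 0.f;
    float globalSum = 0.f;
    for (std::size_t tile = 0; tile < partialOut.size(); ++tile)
    {
        float const factor = std::exp(partialMax[tile] - finalMax); // correction factor per tile
        accumulatedOut += factor * partialOut[tile];
        globalSum += factor * partialSum[tile];
    }
    return accumulatedOut / globalSum; // the kernel applies this inverse sum at the very end (inv_sum_compute)
}

int main()
{
    // Two tiles whose combination must equal a single-pass softmax-weighted average.
    std::printf("combined output: %f\n", combinePartials({2.0f, 6.0f}, {1.0f, 3.0f}, {1.5f, 2.5f}));
    return 0;
}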
|
||||
|
||||
@ -110,11 +110,11 @@ public:
|
||||
const void* qkv;
|
||||
#ifdef USE_KV_SCALE
|
||||
const float* kv_scale_orig_quant = nullptr;
|
||||
const float* kv_scale_quant_orig = nullptr;
|
||||
#endif
|
||||
// Max 3K size
|
||||
KVCache<HasBeam> cacheList[kMaxBatchSizePerWave];
|
||||
int batch_size;
|
||||
const float* kv_scale_quant_orig = nullptr;
|
||||
void* scratch = nullptr;
|
||||
};
|
||||
|
||||
@ -144,21 +144,42 @@ int buildXQALaunchParams(XQALaunchParam<HasBeam>& launchParams, const XQAParams&
|
||||
launchParams.num_k_heads = params.num_kv_heads;
|
||||
#ifdef USE_KV_SCALE
|
||||
launchParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
|
||||
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
||||
#endif
|
||||
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
||||
launchParams.batch_size = micro_batch_size;
|
||||
launchParams.scratch = params.workspaces;
|
||||
|
||||
int max_context_length = 0;
|
||||
int max_past_kv_length = 0;
|
||||
if (params.host_context_lengths)
|
||||
{
|
||||
// TODO: remove this logic, maybe use xqaParams.sequence_lengths inside kernel.
|
||||
max_context_length
|
||||
= *std::max_element(params.host_context_lengths, params.host_context_lengths + params.batch_size);
|
||||
max_past_kv_length = *std::max_element(
|
||||
params.host_past_key_value_lengths, params.host_past_key_value_lengths + params.batch_size);
|
||||
}
|
||||
for (int i = 0; i < micro_batch_size; i++)
|
||||
{
|
||||
int batch_idx = start_batch_idx + i;
|
||||
launchParams.cacheList[i].data = kv_linear_buffer.getKBlockPtr(batch_idx * params.beam_width, 0);
|
||||
// the kernel_mha use KV from KVCache, so need plus 1 here.
|
||||
launchParams.cacheList[i].size = params.host_past_key_value_lengths[batch_idx] + 1;
|
||||
launchParams.cacheList[i].capacity = params.max_kv_cache_length;
|
||||
int current_len = 0;
|
||||
// TODO: remove this logic, maybe use xqaParams.sequence_lengths inside kernel.
|
||||
if (params.host_context_lengths)
|
||||
{
|
||||
// the kernel_mha uses KV from the KV cache, so we need to add 1 here.
|
||||
current_len = params.host_context_lengths[batch_idx] + max_past_kv_length - max_context_length + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
current_len = params.host_past_key_value_lengths[batch_idx] + 1;
|
||||
}
|
||||
launchParams.cacheList[i].size = current_len;
|
||||
launchParams.cacheList[i].capacity = params.max_attention_window_size;
|
||||
if constexpr (HasBeam)
|
||||
{
|
||||
launchParams.cacheList[i].cacheInDir
|
||||
= params.cache_indir + batch_idx * params.beam_width * params.max_kv_cache_length;
|
||||
= params.cache_indir + batch_idx * params.beam_width * params.max_attention_window_size;
|
||||
}
|
||||
}
|
||||
return micro_batch_size;
|
||||
@ -175,6 +196,14 @@ public:
|
||||
, mKernelMeta(&sXqaKernelMetaInfo[0])
|
||||
, mSM(sm)
|
||||
{
|
||||
const char* enable_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
|
||||
if (enable_xqa_env_var != nullptr)
|
||||
{
|
||||
if (enable_xqa_env_var[0] == '1' && enable_xqa_env_var[1] == '\0')
|
||||
{
|
||||
mForceXQA = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void loadXQAKernels()
|
||||
@ -238,8 +267,12 @@ public:
|
||||
return findIter != mFunctions.end();
|
||||
}
|
||||
|
||||
static bool mayHavePerfGain(const XQAParams& xqaParams, int multiprocessor_count)
|
||||
bool mayHavePerfGain(const XQAParams& xqaParams, int multiprocessor_count) const
|
||||
{
|
||||
if (mForceXQA)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
int batch_size = static_cast<int>(xqaParams.batch_size);
|
||||
int multi_block_count = 1;
|
||||
@ -270,7 +303,7 @@ public:
|
||||
|
||||
invokeApplyBiasRopeUpdateKVCache<T, KVLinearBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
|
||||
nullptr, kv_linear_buffer, static_cast<const T*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr,
|
||||
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_kv_cache_length, xqaParams.batch_size * beam_width,
|
||||
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.batch_size * beam_width,
|
||||
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
|
||||
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
|
||||
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, (float*) nullptr, 0,
|
||||
@ -291,8 +324,9 @@ public:
|
||||
XQALaunchParam<HAS_BEAM> launchParams;
|
||||
int micro_batch_size = buildXQALaunchParams(
|
||||
launchParams, xqaParams, kv_linear_buffer, start_batch_idx, multiprocessor_count);
|
||||
void* kernelParams[] = {&launchParams.num_k_heads, &launchParams.output, &launchParams.qkv,
|
||||
&launchParams.cacheList, &launchParams.batch_size, &launchParams.scratch, nullptr};
|
||||
void* kernelParams[]
|
||||
= {&launchParams.num_k_heads, &launchParams.output, &launchParams.qkv, &launchParams.cacheList,
|
||||
&launchParams.batch_size, &launchParams.kv_scale_quant_orig, &launchParams.scratch, nullptr};
|
||||
int multi_block = 1;
|
||||
if (xqaParams.multi_block_mode)
|
||||
{
|
||||
@ -352,6 +386,8 @@ protected:
|
||||
unsigned int mSM;
|
||||
std::unordered_map<const unsigned long long*, CUmodule> mModules;
|
||||
|
||||
bool mForceXQA = false;
|
||||
|
||||
struct XQAKernelFuncInfo
|
||||
{
|
||||
unsigned int mMetaInfoIndex;
|
||||
@ -420,7 +456,7 @@ public:
|
||||
|
||||
bool shouldUse(const XQAParams& xqaParams)
|
||||
{
|
||||
return xqaKernel->supportConfig(xqaParams) && XQAKernelList::mayHavePerfGain(xqaParams, mMultiProcessorCount);
|
||||
return xqaKernel->supportConfig(xqaParams) && xqaKernel->mayHavePerfGain(xqaParams, mMultiProcessorCount);
|
||||
}
|
||||
|
||||
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,
|
||||
|
||||
@ -57,11 +57,12 @@ struct XQAParams
|
||||
const float* kv_scale_orig_quant = nullptr;
|
||||
const float* kv_scale_quant_orig = nullptr;
|
||||
const int32_t* host_past_key_value_lengths = nullptr;
|
||||
const int32_t* host_context_lengths = nullptr;
|
||||
void* workspaces = nullptr;
|
||||
uint32_t batch_size = 0;
|
||||
int32_t beam_width = 0;
|
||||
int32_t max_kv_cache_length = 0;
|
||||
int32_t cyclic_kv_cache_length = 0;
|
||||
int32_t max_attention_window_size = 0;
|
||||
int32_t cyclic_attention_window_size = 0;
|
||||
int timestep = 0;
|
||||
const void* qkv_bias;
|
||||
const int32_t* sequence_lengths; //
|
||||
@ -128,21 +129,10 @@ public:
|
||||
SUPPORT_RETURN_FALSE("rotary_embedding_base");
|
||||
if (xqaParams.rotary_embedding_scale_type != tensorrt_llm::kernels::RotaryScalingType::kNONE)
|
||||
SUPPORT_RETURN_FALSE("rotary_embedding_scale_type");
|
||||
if (xqaParams.rotary_embedding_scale != 1.0f)
|
||||
SUPPORT_RETURN_FALSE("rotary_embedding_scale");
|
||||
if (xqaParams.position_embedding_type != tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX)
|
||||
SUPPORT_RETURN_FALSE("position_embedding_type");
|
||||
// xqaParams.remove_padding
|
||||
if (xqaParams.mask_type != tensorrt_llm::kernels::AttentionMaskType::CAUSAL)
|
||||
SUPPORT_RETURN_FALSE("mask_type");
|
||||
if (xqaParams.paged_kv_cache)
|
||||
SUPPORT_RETURN_FALSE("paged_kv_cache");
|
||||
if (xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::int8KvCache()
|
||||
&& xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::fp8KvCache()
|
||||
&& xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::none())
|
||||
SUPPORT_RETURN_FALSE("kv_cache_quant_mode");
|
||||
if (xqaParams.kv_cache_quant_mode != tensorrt_llm::common::QuantMode::none())
|
||||
SUPPORT_RETURN_FALSE("kv_cache_quant_mode");
|
||||
if (xqaParams.qkv_bias_enabled)
|
||||
SUPPORT_RETURN_FALSE("qkv_bias_enabled");
|
||||
if (xqaParams.cross_attention)
|
||||
@ -152,8 +142,8 @@ public:
|
||||
SUPPORT_RETURN_FALSE("host_past_key_value_lengths");
|
||||
if (xqaParams.beam_width != 1)
|
||||
SUPPORT_RETURN_FALSE("beam_width");
|
||||
if (xqaParams.cyclic_kv_cache_length != xqaParams.max_kv_cache_length)
|
||||
SUPPORT_RETURN_FALSE("cyclic_kv_cache_length != max_kv_cache_length");
|
||||
if (xqaParams.cyclic_attention_window_size != xqaParams.max_attention_window_size)
|
||||
SUPPORT_RETURN_FALSE("cyclic_attention_window_size != max_attention_window_size");
|
||||
return shouldUseImpl(xqaParams);
|
||||
}
|
||||
|
||||
|
||||
@ -134,7 +134,7 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets
|
||||
|
||||
template <typename AttentionMaskDataType>
|
||||
__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength,
|
||||
int maxKvCacheLength, AttentionMaskType attentionMaskType)
|
||||
int attentionWindowSize, AttentionMaskType attentionMaskType)
|
||||
{
|
||||
// The index of the sequence in the batch.
|
||||
int batchIdx = blockIdx.y;
|
||||
@ -174,7 +174,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
|
||||
case AttentionMaskType::CAUSAL:
|
||||
isValid = rowIdx < seqLength && colIdx < seqLength && colIdx <= rowIdx;
|
||||
// Sliding_window_causal when there are not enough kv cache.
|
||||
isValid = isValid && colIdx >= max(0, rowIdx - maxKvCacheLength);
|
||||
isValid = isValid && colIdx >= max(0, rowIdx - attentionWindowSize);
|
||||
// seq_length==4, max_seq_len==5
|
||||
// 1 0 0 0 0
|
||||
// 1 1 0 0 0
|
||||
@ -182,7 +182,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
|
||||
// 1 1 1 1 0
|
||||
// 0 0 0 0 0
|
||||
|
||||
// seq_length==6, max_seq_len==6, max_kv_cache_length = 2
|
||||
// seq_length==6, max_seq_len==6, max_attention_window_size = 2
|
||||
// 1 0 0 0 0 0
|
||||
// 1 1 0 0 0 0
|
||||
// 1 1 1 0 0 0
|
||||
@ -248,7 +248,7 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams<T>& params, cudaStream_
|
||||
}
|
||||
dim3 grid(blocksPerSeq, params.batchSize);
|
||||
computeAttentionMask<<<grid, THREADS_PER_BLOCK, 0, stream>>>(params.attentionMask, params.seqQOffsets,
|
||||
params.maxSeqLength, params.maxKvCacheLength, params.attentionMaskType);
|
||||
params.maxSeqLength, params.attentionWindowSize, params.attentionMaskType);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ struct BuildDecoderInfoParams
|
||||
int maxSeqLength;
|
||||
// The kv cache capacity.
|
||||
// We will apply the limited_length_causal mask when there are not enough kv cache.
|
||||
int maxKvCacheLength;
|
||||
int attentionWindowSize;
|
||||
// The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii].
|
||||
int numTokens;
|
||||
// The type of attention.
|
||||
|
||||
cpp/tensorrt_llm/kernels/groupGemm.cu (new file, 160 lines)
@ -0,0 +1,160 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped.h"
|
||||
|
||||
#include "groupGemm.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2>
|
||||
void run_cutlass_(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA, std::vector<void*> ptrB,
|
||||
std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize, void* cublasWorkSpace,
|
||||
int64_t cublasWorkspaceSize, cudaStream_t stream)
|
||||
{
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
const int kAlignmentA = 8;
|
||||
const int kAlignmentB = 8;
|
||||
|
||||
int problem_count = problem_sizes.size();
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementA, LayoutA,
|
||||
cutlass::ComplexTransform::kNone, kAlignmentA, ElementB, LayoutB, cutlass::ComplexTransform::kNone, kAlignmentB,
|
||||
ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator, ElementAccumulator>,
|
||||
// NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
|
||||
// This parameter is passed in at present to match the APIs of other kernels. The parameter
|
||||
// is unused within the kernel.
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
4, // kStages
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha, beta);
|
||||
|
||||
auto gemm_coord_size = tensorrt_llm::common::divUp(problem_count * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
|
||||
auto ptr_size = tensorrt_llm::common::divUp(problem_count * sizeof(half*), 16) * 16;
|
||||
auto ldd_size = tensorrt_llm::common::divUp(problem_count * sizeof(int64_t), 16) * 16;
|
||||
|
||||
char* host_workspace = (char*) std::malloc(workSpaceSize);
|
||||
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_workspace);
|
||||
ElementA** ptr_A_host = reinterpret_cast<ElementA**>(host_workspace + gemm_coord_size);
|
||||
ElementB** ptr_B_host = reinterpret_cast<ElementB**>(host_workspace + gemm_coord_size + ptr_size);
|
||||
ElementOutput** ptr_C_host = reinterpret_cast<ElementOutput**>(host_workspace + gemm_coord_size + 2 * ptr_size);
|
||||
ElementOutput** ptr_D_host = reinterpret_cast<ElementOutput**>(host_workspace + gemm_coord_size + 3 * ptr_size);
|
||||
int64_t* lda_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 0 * ldd_size);
|
||||
int64_t* ldb_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 1 * ldd_size);
|
||||
int64_t* ldc_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 2 * ldd_size);
|
||||
int64_t* ldd_host = reinterpret_cast<int64_t*>(host_workspace + gemm_coord_size + 4 * ptr_size + 3 * ldd_size);
|
||||
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
problem_sizes_host[i] = problem_sizes.at(i);
|
||||
ptr_A_host[i] = (ElementA*) ptrA.at(i);
|
||||
ptr_B_host[i] = (ElementB*) ptrB.at(i);
|
||||
ptr_C_host[i] = (ElementOutput*) ptrC.at(i);
|
||||
ptr_D_host[i] = (ElementOutput*) ptrD.at(i);
|
||||
|
||||
auto problem = problem_sizes.at(i);
|
||||
lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
|
||||
ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
|
||||
ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
|
||||
ldd_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
|
||||
}
|
||||
|
||||
cutlass::gemm::GemmCoord* problem_sizes_device = reinterpret_cast<cutlass::gemm::GemmCoord*>(workspace);
|
||||
ElementA** ptr_A = reinterpret_cast<ElementA**>((char*) workspace + gemm_coord_size);
|
||||
ElementB** ptr_B = reinterpret_cast<ElementB**>((char*) workspace + gemm_coord_size + ptr_size);
|
||||
ElementOutput** ptr_C = reinterpret_cast<ElementOutput**>((char*) workspace + gemm_coord_size + 2 * ptr_size);
|
||||
ElementOutput** ptr_D = reinterpret_cast<ElementOutput**>((char*) workspace + gemm_coord_size + 3 * ptr_size);
|
||||
int64_t* lda = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 0 * ldd_size);
|
||||
int64_t* ldb = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 1 * ldd_size);
|
||||
int64_t* ldc = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 2 * ldd_size);
|
||||
int64_t* ldd = reinterpret_cast<int64_t*>((char*) workspace + gemm_coord_size + 4 * ptr_size + 3 * ldd_size);
|
||||
|
||||
TLLM_CHECK(((char*) ldc_host - (char*) host_workspace) == ((char*) ldc - (char*) workspace));
|
||||
tensorrt_llm::common::cudaAutoCpy((int8_t*) workspace, (int8_t*) host_workspace, workSpaceSize, stream);
|
||||
|
||||
int threadblock_count = Gemm::sufficient(problem_sizes.data(), problem_count);
|
||||
|
||||
typename Gemm::Arguments args(problem_sizes_device, problem_count, threadblock_count, epilogue_op, ptr_A, ptr_B,
|
||||
ptr_C, ptr_D, lda, ldb, ldc, ldd, problem_sizes.data());
|
||||
|
||||
// Initialize the GEMM object
|
||||
Gemm gemm;
|
||||
|
||||
size_t workspace_size = gemm.get_workspace_size(args);
|
||||
TLLM_CHECK(gemm.get_workspace_size(args) <= cublasWorkspaceSize);
|
||||
|
||||
cutlass::Status status = gemm.initialize(args, cublasWorkSpace);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
// Run the grouped GEMM object
|
||||
status = gemm.run(stream);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to run CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
std::free(host_workspace);
|
||||
}
|
||||
|
||||
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
|
||||
{
|
||||
// For lora in, which has smaller N
|
||||
run_cutlass_<128, 32, 32, 32, 32, 32>(
|
||||
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
|
||||
}
|
||||
|
||||
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream)
|
||||
{
|
||||
// For lora out, which has larger N
|
||||
run_cutlass_<128, 128, 32, 64, 64, 32>(
|
||||
problem_sizes, ptrA, ptrB, ptrC, ptrD, workspace, workSpaceSize, cublasWorkSpace, cublasWorkspaceSize, stream);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
cpp/tensorrt_llm/kernels/groupGemm.h (new file, 35 lines)
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
void run_cutlass_1(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
|
||||
|
||||
void run_cutlass_2(std::vector<cutlass::gemm::GemmCoord> problem_sizes, std::vector<void*> ptrA,
|
||||
std::vector<void*> ptrB, std::vector<void*> ptrC, std::vector<void*> ptrD, void* workspace, int64_t workSpaceSize,
|
||||
void* cublasWorkSpace, int64_t cublasWorkspaceSize, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
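The two entry points declared above batch many independent GEMMs of different shapes into a single launch; run_cutlass_1 uses a narrow tile for the small-N "lora in" projection and run_cutlass_2 a wider tile for the "lora out" projection. A minimal host-side sketch of the semantics a grouped GEMM implements (a plain C++ reference loop with illustrative names, not the CUTLASS kernel itself):

    #include <cstdio>
    #include <vector>

    // Reference semantics of a grouped GEMM over row-major float matrices:
    // for every problem i, compute D[i] = A[i] * B[i], where A[i] is m x k,
    // B[i] is k x n and D[i] is m x n. The CUTLASS kernels above do the same
    // work in one device launch; this loop is only a correctness model.
    struct GemmProblem
    {
        int m, n, k;
        const float* A;
        const float* B;
        float* D;
    };

    void groupedGemmReference(const std::vector<GemmProblem>& problems)
    {
        for (const auto& p : problems)
        {
            for (int i = 0; i < p.m; ++i)
            {
                for (int j = 0; j < p.n; ++j)
                {
                    float acc = 0.f;
                    for (int l = 0; l < p.k; ++l)
                    {
                        acc += p.A[i * p.k + l] * p.B[l * p.n + j];
                    }
                    p.D[i * p.n + j] = acc;
                }
            }
        }
    }

    int main()
    {
        const float A[2 * 3] = {1, 2, 3, 4, 5, 6};
        const float B[3 * 2] = {1, 0, 0, 1, 1, 1};
        float D[2 * 2] = {};
        groupedGemmReference({{2, 2, 3, A, B, D}});
        std::printf("%g %g %g %g\n", D[0], D[1], D[2], D[3]); // prints 4 5 10 11
        return 0;
    }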
@ -411,7 +411,7 @@ template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr unsigned long MAX_BYTES_PER_LDG = 16;
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;

@ -1628,8 +1628,8 @@ INSTANTIATE_TRANSPOSE_4D(half);

template <typename T, typename T_cache, typename KVCacheBuffer>
__global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCacheBuffer kvCacheBuffer,
const int headNum, const int sizePerHead, const int seqLen, const int maxKvCacheLen, const float* kvScaleOrigQuant,
const int* sequence_lengths)
const int headNum, const int sizePerHead, const int seqLen, const int attentionWindowSize,
const float* kvScaleOrigQuant, const int* sequence_lengths)
{
// We allow only fp32/fp16/bf16 as input types
static_assert(sizeof(T) == 4 || sizeof(T) == 2, "");
@ -1655,9 +1655,9 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac

// Get linear token index
int tokenIdx = idx / sizePerHeadDivX;
// Apply cyclic kv cache if tokenIdx >= max_kv_cache_length.
// which means we will drop the tokens in the beginning if seqLen > max_kv_cache_length.
const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - maxKvCacheLen, 0);
// Apply cyclic kv cache if tokenIdx >= max_attention_window_size.
// which means we will drop the tokens in the beginning if seqLen > max_attention_window_size.
const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - attentionWindowSize, 0);
// Get channel index
const int channelIdx = idx % sizePerHeadDivX;
if (tokenIdx >= sequence_lengths[batchIdx] || tokenIdx < tokenIdxLowerBound)
@ -1665,8 +1665,8 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac
return;
}

// Apply cyclic kv cache if tokenIdx >= max_attention_window_size.
tokenIdx = tokenIdx % maxKvCacheLen;
tokenIdx = tokenIdx % attentionWindowSize;

// Get pointer to the dst block given sequence, head and token ids
auto valDst = handle_k ? reinterpret_cast<T_dst*>(kvCacheBuffer.getKBlockPtr(batchIdx, tokenIdx))
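The cyclic write rule in the kernel above can be checked in isolation: a token is kept only if it falls inside the last attention_window_size positions of its sequence, and it is then written at tokenIdx % attentionWindowSize. A minimal standalone sketch of that logic (host C++, illustrative only):

    #include <algorithm>
    #include <cstdio>

    // Host-side model of the cyclic KV-cache write position used by the kernel:
    // tokens outside the last `attentionWindowSize` positions are dropped, the
    // remaining ones wrap around the window.
    int cyclicWritePosition(int tokenIdx, int sequenceLength, int attentionWindowSize)
    {
        const int tokenIdxLowerBound = std::max(sequenceLength - attentionWindowSize, 0);
        if (tokenIdx >= sequenceLength || tokenIdx < tokenIdxLowerBound)
        {
            return -1; // dropped: outside the attention window
        }
        return tokenIdx % attentionWindowSize;
    }

    int main()
    {
        // With sequenceLength = 10 and attentionWindowSize = 4, only tokens 6..9 are
        // kept, and they land at slots 2, 3, 0, 1 of the cyclic cache.
        for (int t = 0; t < 10; ++t)
        {
            std::printf("token %d -> slot %d\n", t, cyclicWritePosition(t, 10, 4));
        }
        return 0;
    }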
@ -1702,7 +1702,7 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac

template <typename T, typename KVCacheBuffer>
void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, const int localBatchSize,
const int seqLen, const int maxKvCacheLen, const int sizePerHead, const int localHeadNum,
const int seqLen, const int attentionWindowSize, const int sizePerHead, const int localHeadNum,
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream)
{
// Block handles both K and V tile.
@ -1714,26 +1714,26 @@ void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kv

if (cache_type == KvCacheDataType::INT8)
{
transpose4dBatchMajorKVCache<T, int8_t, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, int8_t, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc, kvTable,
localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
#ifdef ENABLE_FP8
else if (cache_type == KvCacheDataType::FP8)
{
transpose4dBatchMajorKVCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc,
kvTable, localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
#endif // ENABLE_FP8
else
{
transpose4dBatchMajorKVCache<T, T, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(
kSrc, vSrc, kvTable, localHeadNum, sizePerHead, seqLen, maxKvCacheLen, kvScaleOrigQuant, sequence_lengths);
transpose4dBatchMajorKVCache<T, T, KVCacheBuffer><<<gridSz, blockSz, 0, stream>>>(kSrc, vSrc, kvTable,
localHeadNum, sizePerHead, seqLen, attentionWindowSize, kvScaleOrigQuant, sequence_lengths);
}
}

#define INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR_KV_CACHE_TYPE(T, KVCacheBuffer) \
template void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, \
const int localBatchSize, const int seqLen, const int maxKvCacheLen, const int sizePerHead, \
const int localBatchSize, const int seqLen, const int attentionWindowSize, const int sizePerHead, \
const int localHeadNum, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
const int* sequence_lengths, cudaStream_t stream)

@ -105,7 +105,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const

template <typename T, typename KVCacheBuffer>
void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer& kvTable, const int local_batch_size,
const int seq_len, const int max_kv_cache_len, const int size_per_head, const int local_head_num,
const int seq_len, const int max_attention_window_size, const int size_per_head, const int local_head_num,
const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream);

// NOTE: this kernel is in-place, QKV will be modified, if other kernels need that, may need copy or use before it.

@ -32,41 +32,41 @@ namespace layers

__global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
int local_batch_size, int beam_width, int max_kv_cache_length, int max_seq_len)
int local_batch_size, int beam_width, int max_attention_window, int max_seq_len)
{
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_kv_cache_length)
if (bb_id >= beam_width * local_batch_size || time_step < (max_seq_len - max_attention_window)
|| finished[bb_id].isFinished())
{
return;
}
int time_step_circ = time_step % max_kv_cache_length;
int time_step_circ = time_step % max_attention_window;

// for the parent_ids, we will still keep it for all past tokens (i.e. max_seq_len)
const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step];

// for the indir tables, we have the cyclic kv cache.
const uint32_t tgt_offset
= batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ;
= batch_id * beam_width * max_attention_window + beam_id * max_attention_window + time_step_circ;
const uint32_t src_offset
= batch_id * beam_width * max_kv_cache_length + src_beam * max_kv_cache_length + time_step_circ;
= batch_id * beam_width * max_attention_window + src_beam * max_attention_window + time_step_circ;

tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset];
}

void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids,
const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim,
int local_batch_size, int beam_width, int max_seq_len, int max_kv_cache_length, cudaStream_t stream)
int local_batch_size, int beam_width, int max_seq_len, int max_attention_window, cudaStream_t stream)
{
const dim3 block(32);
// Update indirections steps [input_length[bb_id], sequence_lengths[bb_id]], included
const dim3 grid((max_seq_len + block.x - 1) / block.x, local_batch_size * beam_width);
update_indir_cache_kernel<<<grid, block, 0, stream>>>(tgt_indir_cache, src_indir_cache, parent_ids, finished,
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_kv_cache_length, max_seq_len);
sequence_lengths, input_lengths, batch_dim, local_batch_size, beam_width, max_attention_window, max_seq_len);
}
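The offsets in the kernel above assume an indirection table laid out as [batch, beam, max_attention_window], with the time axis wrapping cyclically just like the KV cache. A small host-side sketch of that addressing (illustrative, not the kernel):

    #include <cstdint>
    #include <cstdio>

    // Host-side model of the indirection-cache addressing used by
    // update_indir_cache_kernel: table shape [batch, beam, max_attention_window],
    // time dimension wrapped with the same cyclic rule as the KV cache.
    uint32_t indirCacheOffset(int batchId, int beamId, int timeStep, int beamWidth, int maxAttentionWindow)
    {
        const int timeStepCirc = timeStep % maxAttentionWindow;
        return static_cast<uint32_t>(batchId) * beamWidth * maxAttentionWindow
            + static_cast<uint32_t>(beamId) * maxAttentionWindow + timeStepCirc;
    }

    int main()
    {
        // batch 1, beam 2 of 4, time step 9 in a window of 8 -> slot 9 % 8 = 1; prints 49.
        std::printf("%u\n", indirCacheOffset(1, 2, 9, 4, 8));
        return 0;
    }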
template <typename T>
@ -191,7 +191,7 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
reinterpret_cast<const FinishedState*>(
outputs.finished->template getPtr<const FinishedState::UnderlyingType>()),
sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len,
params.max_kv_cache_length, stream_);
params.max_attention_window, stream_);
sync_check_cuda_error();
}
sync_check_cuda_error();

@ -56,16 +56,16 @@ public:
{
public:
ForwardParams(int step, int ite, tc::Tensor logits, tc::Tensor endIds, tc::Tensor src_cache_indirection,
int max_kv_cache_length, int max_seq_len)
int max_attention_window, int max_seq_len)
: SoftmaxParams(step, ite, std::move(logits), std::move(endIds))
, src_cache_indirection{std::move(src_cache_indirection)}
, max_kv_cache_length{max_kv_cache_length}
, max_attention_window{max_attention_window}
, max_seq_len{max_seq_len}
{
}

// mandatory parameters
int max_kv_cache_length;
int max_attention_window;
int max_seq_len;
tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len]

@ -296,7 +296,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
auto const end_id_offset
= end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
typename BaseBeamSearchLayer<T>::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset,
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_kv_cache_length),
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_attention_window),
static_cast<std::int32_t>(max_seq_len)};

dynamic_decode_input_tensors.embedding_bias = params.embedding_bias;

@ -79,12 +79,12 @@ public:
class ForwardParams
{
public:
ForwardParams(int step, int ite, int maxInputLength, int maxKvCacheLength, int localBatchSize,
ForwardParams(int step, int ite, int maxInputLength, int maxAttentionWindow, int localBatchSize,
tc::Tensor logits, tc::Tensor endIds)
: step{step}
, ite{ite}
, max_input_length{maxInputLength}
, max_kv_cache_length{maxKvCacheLength}
, max_attention_window{maxAttentionWindow}
, local_batch_size{localBatchSize}
, logits{std::move(logits)}
, end_ids{std::move(endIds)}
@ -95,7 +95,7 @@ public:
int step;
int ite;
int max_input_length;
int max_kv_cache_length;
int max_attention_window;
int local_batch_size;
tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu
tc::Tensor end_ids; // [batch_size], on gpu

@ -84,8 +84,8 @@ struct FusedQKVMaskedAttentionDispatchParams
float rotary_embedding_scale;
int rotary_embedding_max_positions;
PositionEmbeddingType position_embedding_type;
int max_kv_cache_length;
int cyclic_kv_cache_length;
int max_attention_window;
int cyclic_attention_window_size;
const int* input_lengths;
int step;
float q_scaling;
@ -182,11 +182,12 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(
xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant;
xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig;
xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths;
xqaParams.host_context_lengths = generationsParams.host_context_lengths;
xqaParams.workspaces = generationsParams.workspace;
xqaParams.batch_size = generationsParams.num_requests;
xqaParams.beam_width = generationsParams.beam_width;
xqaParams.max_kv_cache_length = generationsParams.max_kv_cache_length;
xqaParams.cyclic_kv_cache_length = generationsParams.cyclic_kv_cache_length;
xqaParams.max_attention_window_size = generationsParams.max_attention_window;
xqaParams.cyclic_attention_window_size = generationsParams.cyclic_attention_window_size;
xqaParams.timestep = generationsParams.past_kv_length;
xqaParams.qkv_bias = generationsParams.qkv_bias;
xqaParams.sequence_lengths = generationsParams.sequence_lengths;
@ -241,8 +242,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.cache_indir = input_params.cache_indir;
params.batch_size = input_params.inference_batch_size;
params.beam_width = input_params.beam_width;
params.max_kv_cache_length = input_params.max_kv_cache_length;
params.cyclic_kv_cache_length = input_params.cyclic_kv_cache_length;
params.max_attention_window_size = input_params.max_attention_window;
params.cyclic_attention_window_size = input_params.cyclic_attention_window_size;
params.length_per_sample = input_params.sequence_lengths; // max_input_length + current output length
// timestep for shared memory size calculation and rotary embedding computation
params.timestep = input_params.step - 1;
@ -550,7 +551,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer = KVCacheBuffer(params.batch_size, 1,
isCrossAttention() ? params.cross_qkv_length : params.max_kv_cache_length,
isCrossAttention() ? params.cross_qkv_length : params.max_attention_window,
num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
}
@ -651,7 +652,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
decoder_params.seqKVLengths = isCrossAttention() ? params.encoder_input_lengths : params.kv_seq_lengths;
decoder_params.batchSize = params.batch_size;
decoder_params.maxSeqLength = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length;
decoder_params.maxKvCacheLength = params.cyclic_kv_cache_length;
decoder_params.attentionWindowSize = params.cyclic_attention_window_size;
decoder_params.numTokens = params.num_tokens;
decoder_params.attentionMaskType = mMaskType;
invokeBuildDecoderInfo(decoder_params, stream);
@ -712,7 +713,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), q_buf_2_, kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,
mRemovePadding ? padding_offset : nullptr, params.batch_size, params.input_seq_length,
params.cyclic_kv_cache_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
params.cyclic_attention_window_size, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
params.kv_scale_orig_quant, enablePagedKVContextFMHA, stream);
@ -733,9 +734,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.

// the token will pay attention to previous tokens while starting from max(0, rowIdx -
// cyclic_kv_cache_length);
// cyclic_attention_window_size);
mFMHARunner->setup_paged_kv(params.batch_size, params.input_seq_length, params.max_past_kv_len,
blocks_per_context_sequence, mTokensPerBlock, params.cyclic_kv_cache_length, params.num_tokens,
blocks_per_context_sequence, mTokensPerBlock, params.cyclic_attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run_paged_kv(q_buf_2_, paged_kv_tma_desc, host_kv_cache_block_ptrs,
reinterpret_cast<KVBlockArray&>(kv_cache_buffer), cu_q_seqlens, cu_kv_seqlens, params.context_buf,
@ -744,8 +745,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
else
{
// the token will pay attention to previous tokens while starting from max(0, rowIdx -
// cyclic_kv_cache_length);
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_kv_cache_length,
// cyclic_attention_window_size);
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.cyclic_attention_window_size,
params.num_tokens, isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_q_seqlens, params.context_buf, stream);
}
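The comment above states the sliding-window rule handed to the FMHA runner: a query at row rowIdx attends to keys starting from max(0, rowIdx - cyclic_attention_window_size). A tiny standalone check of that rule:

    #include <algorithm>
    #include <cstdio>

    // First key position a query token may attend to under the sliding-window rule
    // described when setting up the fused context attention.
    int firstAttendedKey(int rowIdx, int cyclicAttentionWindowSize)
    {
        return std::max(0, rowIdx - cyclicAttentionWindowSize);
    }

    int main()
    {
        // With a window of 4, query token 10 attends to keys 6..10 (causal upper bound),
        // while early tokens still start from key 0. Prints "0 6".
        std::printf("%d %d\n", firstAttendedKey(2, 4), firstAttendedKey(10, 4));
        return 0;
    }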
@ -795,7 +796,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
{
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, params.batch_size,
isCrossAttention() ? params.cross_qkv_length : params.input_seq_length,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, getHeadSize(),
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, getHeadSize(),
mNumKVHeads, cache_type, params.kv_scale_orig_quant,
isCrossAttention() ? params.encoder_input_lengths : params.q_seq_lengths, stream);
}
@ -891,7 +892,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
mNumHeads, attention_seq_len_1,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, stream,
max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
}

@ -918,8 +919,9 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// already max_output_len + 1. In implicit mode, relative_attention_bias is relative_attention_table
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_, relative_attention_bias, params.batch_size, mNumHeads,
attention_seq_len_1, isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length,
stream, max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
attention_seq_len_1,
isCrossAttention() ? params.cross_qkv_length : params.cyclic_attention_window_size, stream,
max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
}

MaskedSoftmaxParam<T, T> param;
@ -1078,7 +1080,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(
{
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
kv_cache_buffer
= KVCacheBuffer(batch_beam, 1, params.max_kv_cache_length, num_kv_heads * head_size * elem_size);
= KVCacheBuffer(batch_beam, 1, params.max_attention_window, num_kv_heads * head_size * elem_size);
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
}
}
@ -1098,8 +1100,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
}

int timestep = params.past_kv_length;
const int max_timesteps
= mCrossAttention ? params.cyclic_kv_cache_length : std::min(timestep, params.cyclic_kv_cache_length);
const int max_timesteps = mCrossAttention ? params.cyclic_attention_window_size
: std::min(timestep, params.cyclic_attention_window_size);
int estimated_min_multi_block_count
= estimate_min_multi_block_count<T>(max_timesteps, mMaxSharedMemoryPerBlockOptin - 2048);

@ -1157,8 +1159,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
dispatch_params.size_per_head = getHeadSize();
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
dispatch_params.position_embedding_type = mPositionEmbeddingType;
dispatch_params.max_kv_cache_length = params.max_kv_cache_length;
dispatch_params.cyclic_kv_cache_length = params.cyclic_kv_cache_length;
dispatch_params.max_attention_window = params.max_attention_window;
dispatch_params.cyclic_attention_window_size = params.cyclic_attention_window_size;
dispatch_params.input_lengths = params.context_lengths;
dispatch_params.step = step;
dispatch_params.q_scaling = q_scaling;

@ -87,12 +87,12 @@ protected:
T const* qkv_bias;
int32_t input_seq_length; // padded input length
int32_t max_past_kv_len;
// By default, max_kv_cache_length == cyclic_kv_cache_length
// By default, max_attention_window == cyclic_attention_window_size
// unless each layer has different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int32_t max_kv_cache_length;
int32_t max_attention_window;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int32_t cyclic_kv_cache_length;
int32_t cyclic_attention_window_size;
int32_t const* q_seq_lengths;
int32_t const* kv_seq_lengths;
float const* kv_scale_orig_quant;
@ -134,12 +134,12 @@ protected:
T* context_buf;
void* key_value_cache;
void* block_pointers;
// By default, max_kv_cache_length == cyclic_kv_cache_length
// By default, max_attention_window == cyclic_attention_window_size
// unless each layer has different cyclic kv cache length.
// Max cache capacity (used to allocate KV cache)
int32_t max_kv_cache_length;
int32_t max_attention_window;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int32_t cyclic_kv_cache_length;
int32_t cyclic_attention_window_size;
int32_t num_requests;
int32_t max_blocks_per_sequence;
int32_t const* cache_indir;
@ -150,6 +150,7 @@ protected:
int relative_attention_bias_stride = 0;
// optional when cross attention
int32_t const* encoder_input_lengths = nullptr;
int32_t const* host_context_lengths = nullptr;
};

template <typename T, typename KVCacheBuffer>

@ -66,7 +66,7 @@ bool GPTAttentionPlugin::isEntryUsed(const IdxEntry& entry) const
case IdxEntry::QKV_TENSOR: return true;
case IdxEntry::SEQUENCE_LENGTH: return useKVCache();
case IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS: return useKVCache();
case IdxEntry::HOST_MAX_KV_CACHE_LENGTH: return true;
case IdxEntry::HOST_MAX_ATTENTION_WINDOW: return true;
case IdxEntry::CONTEXT_LENGTHS: return true;
case IdxEntry::CACHE_INDIR: return useKVCache();
case IdxEntry::REQUEST_TYPES: return true;
@ -132,7 +132,7 @@ bool GPTAttentionPlugin::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == getIdx(IdxEntry::CONTEXT_LENGTHS) || pos == getIdx(IdxEntry::REQUEST_TYPES)
|| pos == getIdx(IdxEntry::HOST_MAX_KV_CACHE_LENGTH))
|| pos == getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
@ -319,17 +319,17 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32

const int beamWidth = useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1;

// Commonly, cyclic kv cache length, and max kv cache length will be the same
// unless each layer has different max kv cache length.
// Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.
const int max_kv_cache_length = isCrossAttention()
const int max_attention_window_size = isCrossAttention()
? max_encoder_context_len
: (useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
// The cyclic_kv_cache_length will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_kv_cache_length might be smaller than the actual kv cache capactity (max_kv_cache_length).
const int cyclic_kv_cache_length = isCrossAttention()
// The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity.
const int cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_KV_CACHE_LENGTH)])[0];
: reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[0];

const float* kv_scale_orig_quant = nullptr;
const float* kv_scale_quant_orig = nullptr;
@ -395,9 +395,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
}

EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_q_len,
max_context_kv_len, max_kv_cache_length, cyclic_kv_cache_length, context_q_lengths, sequence_kv_length,
kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache, block_pointers,
host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
max_context_kv_len, max_attention_window_size, cyclic_attention_window_size, context_q_lengths,
sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache,
block_pointers, host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
if (isRelativePosition())
{
enqueue_params.relative_attention_bias
@ -425,11 +425,14 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32

const int* cache_indir
= beamWidth == 1 ? nullptr : reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::CACHE_INDIR)]);
const int* host_context_lengths
= mRemovePadding ? reinterpret_cast<const int*>(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) : nullptr;

EnqueueGenerationParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, sequence_kv_length,
max_context_kv_len, beamWidth, context_q_lengths, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes,
context_buf_, key_value_cache, block_pointers, max_kv_cache_length, cyclic_kv_cache_length, num_requests,
max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
context_buf_, key_value_cache, block_pointers, max_attention_window_size, cyclic_attention_window_size,
num_requests, max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
enqueue_params.host_context_lengths = host_context_lengths;
if (isRelativePosition())
{
enqueue_params.relative_attention_bias

@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins
// enable_remove_input_padding
// 1. sequence_length [batch_size] (optional)
// 2. host_past_key_value_lengths [batch_size] (int32) (optional)
// 3. host_max_kv_cache_lengths [1] (int32)
// 3. host_max_attention_window_sizes [1] (int32)
// 4. context_lengths [batch_size]
// 5. cache_indir [num_gen_requests, beam_width, memory_max_len] (required in beamsearch) (optional)
// 6. host_request_types [batch_size] int32. 0: context; 1: generation: 2: none. When not in inflight-batching
@ -141,7 +141,7 @@ private:
QKV_TENSOR,
SEQUENCE_LENGTH,
HOST_PAST_KEY_VALUE_LENGTHS,
HOST_MAX_KV_CACHE_LENGTH,
HOST_MAX_ATTENTION_WINDOW,
CONTEXT_LENGTHS,
CACHE_INDIR,
REQUEST_TYPES,

@ -15,9 +15,16 @@
* limitations under the License.
*/
#include "loraPlugin.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/groupGemm.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>

using namespace nvinfer1;
using namespace tensorrt_llm::common;
using tensorrt_llm::plugins::LoraPluginCreator;
@ -31,6 +38,10 @@ static const char* LORA_PLUGIN_NAME{"Lora"};
PluginFieldCollection LoraPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> LoraPluginCreator::mPluginAttributes;

// TODO should be managed by better way
static std::vector<cublasHandle_t> cublas_handles;
static std::vector<cudaStream_t> streams;

// TODO should reuse the function in gemmPlugin
void _getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb,
int& ldc, bool transA, bool transB, int M, int N, int K)
@ -283,6 +294,34 @@ void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in
mGemmId.k = K;
}

size_t getLowRankWorkSpaceSize(int nbReq, int maxContextLength, int maxLowRank, int typeSize)
{
return (size_t) divUp(nbReq * maxContextLength * maxLowRank * typeSize, 16) * 16;
}

size_t getCutlassWorkSpaceSize(int nbReq)
{
auto gemm_coord_size = divUp(nbReq * sizeof(cutlass::gemm::GemmCoord), 16) * 16;
auto ptr_size = 4 * divUp(nbReq * sizeof(half*), 16) * 16;
auto ldd_size = 4 * divUp(nbReq * sizeof(int64_t), 16) * 16;

return gemm_coord_size + ptr_size + ldd_size;
}
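The two helpers above size the extra workspace in 16-byte-aligned chunks: one low-rank intermediate buffer plus the grouped-GEMM bookkeeping (problem shapes, four pointer arrays, four leading-dimension arrays). A standalone version of the same arithmetic, evaluated for a hypothetical batch of 8 fp16 requests (divUp re-implemented here; GemmCoord treated as three ints, an assumption about its size):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Standalone copy of the alignment arithmetic used by the workspace helpers.
    static size_t divUp(size_t a, size_t b)
    {
        return (a + b - 1) / b;
    }

    int main()
    {
        const int nbReq = 8;              // hypothetical number of requests
        const int maxContextLength = 512; // hypothetical
        const int maxLowRank = 64;        // hypothetical
        const int typeSize = 2;           // fp16

        // Low-rank intermediate buffer, padded to a multiple of 16 bytes.
        const size_t lowRank = divUp(size_t(nbReq) * maxContextLength * maxLowRank * typeSize, 16) * 16;

        // CUTLASS grouped-GEMM bookkeeping: problem shapes (assumed 12 bytes each),
        // 4 pointer arrays and 4 leading-dimension arrays, each padded to 16 bytes.
        const size_t gemmCoords = divUp(nbReq * size_t(12), 16) * 16;
        const size_t ptrArrays = 4 * divUp(nbReq * sizeof(void*), 16) * 16;
        const size_t lddArrays = 4 * divUp(nbReq * sizeof(int64_t), 16) * 16;

        std::printf("low-rank: %zu bytes, cutlass: %zu bytes\n", lowRank, gemmCoords + ptrArrays + lddArrays);
        return 0;
    }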
LoraPlugin::~LoraPlugin()
{
for (int i = 0; i < streams.size(); i++)
{
if (i != 0)
{
cudaStreamDestroy(streams.at(i));
}
cublasDestroy(cublas_handles.at(i));
}
streams.clear();
cublas_handles.clear();
}

size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
{
@ -291,15 +330,30 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in
auto const type = inputs[getInputTensorIdx()].type;
auto const typeSize = tensorrt_llm::runtime::BufferDataType(type).getSize();

size_t const lowRankWorkSpaceSize = nbReq * mMaxContextLength * mMaxLowRank * typeSize;
return CUBLAS_WORKSPACE_SIZE + getLowRankWorkSpaceSize(nbReq, mMaxContextLength, mMaxLowRank, typeSize)
+ getCutlassWorkSpaceSize(nbReq);
}

return CUBLAS_WORKSPACE_SIZE + lowRankWorkSpaceSize;
void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act,
const void* weight, void* output, cublasHandle_t cublas_handle)
{
float a = 1.0f;
float b = 0.0f;
void* alpha = &a;
void* beta = &b;
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
_getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);

tensorrt_llm::common::check_cuda_error(cublasGemmEx(cublas_handle, transa, transb, m, n, k, alpha, weight,
CUDA_R_16F, lda, act, CUDA_R_16F, ldb, beta, output, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
}
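_getProblemParams itself is not shown in this diff; wrappers like runCublasGemmEx usually map a row-major act × weight product onto column-major cuBLAS by computing the transposed product with the operands swapped, which would also explain why the weight pointer is passed as the first matrix above. A self-contained FP32 sketch of that standard mapping (an assumption about the helper, not its actual code):

    #include <cublas_v2.h>
    #include <cuda_runtime.h>

    // Row-major C[M,N] = act[M,K] * weight[K,N] expressed through column-major cuBLAS:
    // compute C^T = weight^T * act^T by passing the operands swapped, with leading
    // dimensions taken from the row-major strides. Whether _getProblemParams does
    // exactly this is an assumption.
    void rowMajorGemm(cublasHandle_t handle, int M, int N, int K,
        const float* act /* [M,K] row-major */, const float* weight /* [K,N] row-major */,
        float* out /* [M,N] row-major */)
    {
        const float alpha = 1.0f;
        const float beta = 0.0f;
        // m = N, n = M, k = K in cuBLAS's column-major view.
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
            &alpha, weight, N, act, K, &beta, out, N);
    }

    int main()
    {
        const int M = 2, N = 3, K = 4;
        float *act, *weight, *out;
        cudaMalloc(&act, M * K * sizeof(float));
        cudaMalloc(&weight, K * N * sizeof(float));
        cudaMalloc(&out, M * N * sizeof(float));
        cublasHandle_t handle;
        cublasCreate(&handle);
        rowMajorGemm(handle, M, N, K, act, weight, out);
        cudaDeviceSynchronize();
        cublasDestroy(handle);
        cudaFree(act);
        cudaFree(weight);
        cudaFree(out);
        return 0;
    }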
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
// inputs
// input [-1, K] (view as 2D)
// host_request_type [batch_size] on cpu
@ -309,10 +363,9 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
// outputs
// output [-1, N] (view as 2D)

auto const typeSize = tensorrt_llm::runtime::BufferDataType(mType).getSize();
void* cublasWorkSpace = workspace;
void* lowRankWorkSpace = static_cast<char*>(cublasWorkSpace) + CUBLAS_WORKSPACE_SIZE;
TLLM_CHECK(mType == DataType::kHALF); // Only support on half now, will extend to more data type in near future.

auto const typeSize = tensorrt_llm::runtime::BufferDataType(mType).getSize();
setGemmConfig();
auto const batch_size = inputDesc[getLoraRanksIdx()].dims.d[0];
auto const lora_ranks = static_cast<int32_t const*>(inputs[getLoraRanksIdx()]);
@ -321,55 +374,230 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf
= mRemoveInputPadding ? static_cast<int32_t const*>(inputs[getHostContextLengthsIdx()]) : nullptr;
RequestType const* reqTypes = static_cast<RequestType const*>(inputs[getHostRequestTypesIdx()]);

void* cublasWorkSpace = workspace;
void* lowRankWorkSpace = static_cast<char*>(cublasWorkSpace) + CUBLAS_WORKSPACE_SIZE;
void* cutlassWorkSpace = static_cast<char*>(lowRankWorkSpace)
+ getLowRankWorkSpaceSize(batch_size, mMaxContextLength, mMaxLowRank, typeSize);
size_t cutlassWorkSpaceSize = getCutlassWorkSpaceSize(batch_size);
size_t handled_token_num = 0;

// TODO only initialize output buffer when the lora rank is -1
const int nbDimsA = inputDesc[0].dims.nbDims;

bool useUnifyGemm = false;
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];

if (lora_rank <= 0)
if (lora_weights_ptr[2 * batchIdx] != lora_weights_ptr[0]
|| lora_weights_ptr[2 * batchIdx + 1] != lora_weights_ptr[1] || lora_ranks[batchIdx] == 0)
{
const auto N = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1];
void* output = static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N * typeSize);
if (typeSize == 2)
{
deviceFill((half*) output, M * N, (half) 0.0f, stream);
}
else
{
deviceFill((float*) output, M * N, 0.0f, stream);
}
useUnifyGemm = false;
}
else
{
// the input shape should be [1, token_num, K] under remove_input_padding,
// [batch, seqlen, K] under non-remove_input_padding
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
}

const int nbDimsA = inputDesc[0].dims.nbDims;
bool UseGroupedGemm = true;
if (useUnifyGemm)
{
const RequestType reqType = reqTypes[0];
int M = 0;
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
M += (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
}
const auto lora_rank = lora_ranks[0];
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);

const auto N = lora_rank;

TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]

void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[0]);
void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[1]);
const void* input = inputs[0];
void* output = outputs[0];

_runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, input, lora_in_weight, lowRankWorkSpace, bestTactic,
cublasWorkSpace, stream);

_runGemm(M, N2, N, mTransA, mTransB, mType, mCublasWrapper, lowRankWorkSpace, lora_out_weight, output,
bestTactic, cublasWorkSpace, stream);
}
else if (UseGroupedGemm)
{
int handled_token_num = 0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes;
problem_sizes.reserve(batch_size);
std::vector<void*> ptrA;
ptrA.reserve(batch_size);
std::vector<void*> ptrB;
ptrB.reserve(batch_size);
std::vector<void*> ptrC;
ptrC.reserve(batch_size);
std::vector<void*> ptrD;
ptrD.reserve(batch_size);

std::vector<cutlass::gemm::GemmCoord> problem_sizes_2;
problem_sizes_2.reserve(batch_size);
std::vector<void*> ptrA_2;
ptrA_2.reserve(batch_size);
std::vector<void*> ptrB_2;
ptrB_2.reserve(batch_size);
std::vector<void*> ptrC_2;
ptrC_2.reserve(batch_size);
std::vector<void*> ptrD_2;
ptrD_2.reserve(batch_size);
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];
const auto N = lora_rank;
if (N > 0)
{
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]
cutlass::gemm::GemmCoord problem(M, N, K);
problem_sizes.push_back(problem);

ptrA.push_back(static_cast<void*>(
static_cast<char*>(const_cast<void*>(inputs[0])) + handled_token_num * K * typeSize));
ptrB.push_back(reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]));
ptrC.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));
ptrD.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));

cutlass::gemm::GemmCoord problem_2(M, N2, N);
problem_sizes_2.push_back(problem_2);
ptrA_2.push_back(static_cast<void*>(
static_cast<char*>(lowRankWorkSpace) + handled_token_num * mMaxLowRank * typeSize));
ptrB_2.push_back(reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]));
ptrC_2.push_back(
static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize));
ptrD_2.push_back(
static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize));
}
handled_token_num += M;
}
tensorrt_llm::kernels::run_cutlass_1(problem_sizes, ptrA, ptrB, ptrC, ptrD, cutlassWorkSpace,
cutlassWorkSpaceSize, cublasWorkSpace, CUBLAS_WORKSPACE_SIZE, stream);
sync_check_cuda_error();
tensorrt_llm::kernels::run_cutlass_2(problem_sizes_2, ptrA_2, ptrB_2, ptrC_2, ptrD_2, cutlassWorkSpace,
cutlassWorkSpaceSize, cublasWorkSpace, CUBLAS_WORKSPACE_SIZE, stream);
sync_check_cuda_error();
}
else
{
if (streams.size() != batch_size)
{
for (int i = 0; i < batch_size; i++)
{
// TLLM_LOG_INFO("allocate %d stream and handle", i);
cudaStream_t stream_;
cublasHandle_t handle;
if (i == 0)
{
stream_ = stream;
}
else
{
cudaStreamCreate(&stream_);
}
cublasCreate(&handle);
cublasSetStream(handle, stream_);
cublas_handles.push_back(handle);
streams.push_back(stream_);
cudaStreamSynchronize(stream_);
}
}

if (!graphCreated)
{
cudaGraph_t graph;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
std::vector<cudaEvent_t> events;
events.reserve(batch_size);

for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
{
cublasSetStream(cublas_handles[batchIdx], streams[batchIdx]);
cudaEventCreate(&events[batchIdx]);

const RequestType reqType = reqTypes[batchIdx];
const auto M = (reqType != RequestType::kCONTEXT)
? 1
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
const auto lora_rank = lora_ranks[batchIdx];

if (lora_rank <= 0)
{
const auto N = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1];
void* output
= static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N * typeSize);
if (typeSize == 2)
{
deviceFill((half*) output, M * N, (half) 0.0f, streams[batchIdx]);
}
else
{
deviceFill((float*) output, M * N, 0.0f, streams[batchIdx]);
}
}
else
{
// the input shape should be [1, token_num, K] under remove_input_padding,
// [batch, seqlen, K] under non-remove_input_padding
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);

const auto N = lora_rank;

TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
fmtstr(
"Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
const auto K
= mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
// [M, K] -> [M, N] -> [M, N2]

void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]);
void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]);
const void* input = static_cast<const void*>(
static_cast<const char*>(inputs[0]) + handled_token_num * K * typeSize);
void* output
= static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize);

if (batchIdx > 0)
{
cudaStreamWaitEvent(streams[batchIdx], events[0]);
}

runCublasGemmEx(
M, N, K, mTransA, mTransB, input, lora_in_weight, lowRankWorkSpace, cublas_handles[batchIdx]);
runCublasGemmEx(M, N2, N, mTransA, mTransB, lowRankWorkSpace, lora_out_weight, output,
cublas_handles[batchIdx]);

cudaEventRecord(events[batchIdx], streams[batchIdx]);
}
handled_token_num += M;

cudaStreamWaitEvent(stream, events[batchIdx], 0);
}

cudaStreamEndCapture(stream, &graph);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
graphCreated = true;
}
cudaGraphLaunch(instance, stream);
}
return 0;
}
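Whichever dispatch path is taken above (unified GEMM, grouped CUTLASS GEMMs, or per-request cuBLAS calls captured into a graph), each request applies the same two-step low-rank update: the [M, K] activations are projected to the low rank and then expanded to the output width, i.e. [M, K] -> [M, N] -> [M, N2]. A plain reference model of that math (all shapes illustrative):

    #include <cstdio>
    #include <vector>

    // Reference model of the LoRA low-rank path: out = (x * A) * B, where x is
    // [M, K], A (the "lora in" weight) is [K, r], B (the "lora out" weight) is
    // [r, N2] and out is [M, N2]. All matrices are row-major.
    std::vector<float> loraReference(const std::vector<float>& x, const std::vector<float>& A,
        const std::vector<float>& B, int M, int K, int r, int N2)
    {
        std::vector<float> low(M * r, 0.f); // the low-rank intermediate ([M, r])
        for (int i = 0; i < M; ++i)
            for (int j = 0; j < r; ++j)
                for (int l = 0; l < K; ++l)
                    low[i * r + j] += x[i * K + l] * A[l * r + j];

        std::vector<float> out(M * N2, 0.f);
        for (int i = 0; i < M; ++i)
            for (int j = 0; j < N2; ++j)
                for (int l = 0; l < r; ++l)
                    out[i * N2 + j] += low[i * r + l] * B[l * N2 + j];
        return out;
    }

    int main()
    {
        // rank-1 update on a single token: x = [1, 2], A = [[1], [1]], B = [[2, 3]].
        const std::vector<float> x{1.f, 2.f}, A{1.f, 1.f}, B{2.f, 3.f};
        const auto out = loraReference(x, A, B, /*M=*/1, /*K=*/2, /*r=*/1, /*N2=*/2);
        std::printf("%g %g\n", out[0], out[1]); // (1 + 2) * [2, 3] = [6, 9]
        return 0;
    }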
@ -44,7 +44,7 @@ public:

LoraPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler);

~LoraPlugin() override = default;
~LoraPlugin(); // override = default;

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
@ -132,6 +132,9 @@ private:
GemmIdCublas mGemmId{};

PluginProfilerPtr mPluginProfiler;

bool graphCreated = false;
cudaGraphExec_t instance;
};

class LoraPluginCreator : public BaseCreator
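The graphCreated and instance members added above back the capture-once, replay-many path in enqueue(). A minimal standalone sketch of that pattern — fork a side stream off the capturing stream with an event, join it back, then instantiate and launch the graph — with the captured work reduced to memsets and the cudaGraphInstantiate call mirroring the five-argument form used above (newer CUDA toolkits also offer a flags-only overload):

    #include <cuda_runtime.h>
    #include <cstdio>

    #define CHECK(call)                                                                                                \
        do                                                                                                             \
        {                                                                                                              \
            cudaError_t e = (call);                                                                                    \
            if (e != cudaSuccess)                                                                                      \
            {                                                                                                          \
                std::printf("CUDA error: %s\n", cudaGetErrorString(e));                                                \
                return 1;                                                                                              \
            }                                                                                                          \
        } while (0)

    int main()
    {
        float *a = nullptr, *b = nullptr;
        CHECK(cudaMalloc(&a, 1024 * sizeof(float)));
        CHECK(cudaMalloc(&b, 1024 * sizeof(float)));

        cudaStream_t main_stream, side_stream;
        CHECK(cudaStreamCreate(&main_stream));
        CHECK(cudaStreamCreate(&side_stream));
        cudaEvent_t fork, join;
        CHECK(cudaEventCreate(&fork));
        CHECK(cudaEventCreate(&join));

        // Capture once: fork the side stream off the capturing stream with an event,
        // do work on both streams, then join back before ending the capture.
        cudaGraph_t graph;
        CHECK(cudaStreamBeginCapture(main_stream, cudaStreamCaptureModeGlobal));
        CHECK(cudaEventRecord(fork, main_stream));
        CHECK(cudaStreamWaitEvent(side_stream, fork, 0));
        CHECK(cudaMemsetAsync(a, 0, 1024 * sizeof(float), main_stream)); // stands in for work on the main stream
        CHECK(cudaMemsetAsync(b, 0, 1024 * sizeof(float), side_stream)); // stands in for work on a side stream
        CHECK(cudaEventRecord(join, side_stream));
        CHECK(cudaStreamWaitEvent(main_stream, join, 0));
        CHECK(cudaStreamEndCapture(main_stream, &graph));

        cudaGraphExec_t instance;
        CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

        // Replay the captured work as many times as needed.
        CHECK(cudaGraphLaunch(instance, main_stream));
        CHECK(cudaStreamSynchronize(main_stream));

        CHECK(cudaGraphExecDestroy(instance));
        CHECK(cudaGraphDestroy(graph));
        return 0;
    }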
@ -19,12 +19,16 @@
|
||||
#include "namedTensor.h"
|
||||
#include "tensorrt_llm/batch_manager/GptManager.h"
|
||||
#include "tensorrt_llm/batch_manager/callbacks.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/pybind/utils/pathCaster.h"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
@ -51,8 +55,8 @@ py::object GptManager::enter()
|
||||
void GptManager::exit(py::handle type, py::handle value, py::handle traceback)
|
||||
{
|
||||
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
|
||||
// able to do forward progress for the shutdown process to succeed. For that, we must manually release the GIL while
|
||||
// waiting in `process.join()`.
|
||||
// able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
|
||||
// we release it now. Note that we shouldn't do anything related to python objects after that.
|
||||
py::gil_scoped_release release;
|
||||
shutdown();
|
||||
}
|
||||
@ -85,4 +89,21 @@ tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
|
||||
callback(id, pythonList, isOk, errMsg);
|
||||
};
|
||||
}
|
||||
|
||||
void GptManager::initBindings(py::module_& m)
|
||||
{
|
||||
py::class_<GptManager>(m, "GptManager")
|
||||
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
|
||||
GetInferenceRequestsCallback, SendResponseCallback, tb::PollStopSignalCallback,
|
||||
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
|
||||
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
|
||||
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
|
||||
py::arg("return_batch_manager_stats_cb") = nullptr,
|
||||
py::arg_v("optional_params", tb::TrtGptModelOptionalParams(), "TrtGptModelOptionalParams"),
|
||||
py::arg("terminate_req_id") = std::nullopt)
|
||||
.def("shutdown", &GptManager::exit)
|
||||
.def("__enter__", &GptManager::enter)
|
||||
.def("__exit__", &GptManager::exit);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
|
||||
@ -25,30 +25,31 @@
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <functional>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
|
||||
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
|
||||
|
||||
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
|
||||
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback);
|
||||
tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
|
||||
tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback callback);
|
||||
|
||||
class GptManager : tb::GptManager
|
||||
class GptManager : tensorrt_llm::batch_manager::GptManager
|
||||
{
|
||||
public:
|
||||
GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
|
||||
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
|
||||
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
const tb::TrtGptModelOptionalParams& optionalParams = tb::TrtGptModelOptionalParams(),
|
||||
GptManager(std::filesystem::path const& trtEnginePath, tensorrt_llm::batch_manager::TrtGptModelType modelType,
|
||||
int32_t maxBeamWidth, tensorrt_llm::batch_manager::batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
GetInferenceRequestsCallback getInferenceRequestsCb, SendResponseCallback sendResponseCb,
|
||||
tensorrt_llm::batch_manager::PollStopSignalCallback pollStopSignalCb = nullptr,
|
||||
tensorrt_llm::batch_manager::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
|
||||
const tensorrt_llm::batch_manager::TrtGptModelOptionalParams& optionalParams
|
||||
= tensorrt_llm::batch_manager::TrtGptModelOptionalParams(),
|
||||
std::optional<uint64_t> terminateReqId = std::nullopt);
|
||||
|
||||
py::object enter();
|
||||
void exit(py::handle type, py::handle value, py::handle traceback);
|
||||
pybind11::object enter();
|
||||
void exit(pybind11::handle type, pybind11::handle value, pybind11::handle traceback);
|
||||
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
|
||||
@ -20,6 +20,11 @@
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
@ -48,5 +53,46 @@ std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
|
||||
tensorMap[name] = tr::TorchView::of(tensor.value());
|
||||
}
|
||||
}
|
||||
return std::make_shared<tb::InferenceRequest>(tensorMap, mRequestId);
|
||||
auto inferenceRequest = std::make_shared<tb::InferenceRequest>(tensorMap, mRequestId);
|
||||
inferenceRequest->setIsStreaming(isStreaming());
|
||||
return inferenceRequest;
|
||||
}
|
||||
|
||||
void InferenceRequest::initBindings(py::module_& m)
|
||||
{
|
||||
py::class_<InferenceRequest>(m, "InferenceRequest")
|
||||
.def(py::init<uint64_t>())
|
||||
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
|
||||
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
|
||||
.def_property(
|
||||
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
|
||||
.def_property("draft_logits", &InferenceRequest::getDraftLogitsUnchecked, &InferenceRequest::setDraftLogits)
|
||||
.def_property("max_new_tokens", &InferenceRequest::getMaxNewTokensUnchecked, &InferenceRequest::setMaxNewTokens)
|
||||
.def_property("beam_width", &InferenceRequest::getBeamWidthUnchecked, &InferenceRequest::setBeamWidth)
|
||||
.def_property("end_id", &InferenceRequest::getEndIdUnchecked, &InferenceRequest::setEndId)
|
||||
.def_property("pad_id", &InferenceRequest::getPadIdUnchecked, &InferenceRequest::setPadId)
|
||||
.def_property("bad_words_list", &InferenceRequest::getBadWordsListUnchecked, &InferenceRequest::setBadWordsList)
|
||||
.def_property(
|
||||
"stop_words_list", &InferenceRequest::getStopWordsListUnchecked, &InferenceRequest::setStopWordsList)
|
||||
.def_property(
|
||||
"embedding_bias", &InferenceRequest::getEmbeddingBiasUnchecked, &InferenceRequest::setEmbeddingBias)
|
||||
.def_property("temperature", &InferenceRequest::getTemperatureUnchecked, &InferenceRequest::setTemperature)
|
||||
.def_property("runtime_top_k", &InferenceRequest::getRuntimeTopKUnchecked, &InferenceRequest::setRuntimeTopK)
|
||||
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
|
||||
.def_property(
|
||||
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
|
||||
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
|
||||
&InferenceRequest::setRepetitionPenalty)
|
||||
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)
|
||||
.def_property(
|
||||
"presence_penalty", &InferenceRequest::getPresencePenaltyUnchecked, &InferenceRequest::setPresencePenalty)
|
||||
.def_property("random_seed", &InferenceRequest::getRandomSeedUnchecked, &InferenceRequest::setRandomSeed)
|
||||
.def_property(
|
||||
"return_log_probs", &InferenceRequest::getReturnLogProbsUnchecked, &InferenceRequest::setReturnLogProbs)
|
||||
.def_property("prompt_embedding_table", &InferenceRequest::getPromptEmbeddingTableUnchecked,
|
||||
&InferenceRequest::setPromptEmbeddingTable)
|
||||
.def_property(
|
||||
"prompt_vocab_size", &InferenceRequest::getPromptVocabSizeUnchecked, &InferenceRequest::setPromptVocabSize)
|
||||
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
|
||||
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
|
||||
}

@@ -19,9 +19,9 @@

#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include "tensorrt_llm/runtime/common.h"

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>

#include <memory>
#include <optional>
@@ -53,6 +53,7 @@ public:
}

[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::InferenceRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};

} // namespace tensorrt_llm::pybind::batch_manager

@ -21,6 +21,11 @@
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
@ -53,3 +58,61 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mReturnLogProbs,
|
||||
mDraftTokens, draftLogits);
|
||||
}
|
||||
|
||||
void LlmRequest::initBindings(py::module_& m)
|
||||
{
|
||||
py::class_<LlmRequest>(m, "LlmRequest")
|
||||
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, LlmRequest::VecTokens, tr::SamplingConfig, bool,
|
||||
std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>,
|
||||
std::optional<LlmRequest::TensorPtr>>(),
|
||||
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
|
||||
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
|
||||
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
|
||||
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
|
||||
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
|
||||
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
|
||||
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
|
||||
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
|
||||
.def("get_tokens", py::overload_cast<LlmRequest::SizeType>(&LlmRequest::getTokens, py::const_), py::arg("beam"))
|
||||
.def("get_tokens", py::overload_cast<>(&LlmRequest::getTokens, py::const_))
|
||||
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
|
||||
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
|
||||
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
|
||||
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
|
||||
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
|
||||
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
|
||||
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
|
||||
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
|
||||
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
|
||||
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
|
||||
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
|
||||
.def_readwrite("request_id", &LlmRequest::mRequestId)
|
||||
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
|
||||
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
|
||||
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
|
||||
.def_readwrite("state", &LlmRequest::mState)
|
||||
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
|
||||
.def_readwrite("end_id", &LlmRequest::mEndId)
|
||||
.def_readwrite("pad_id", &LlmRequest::mPadId)
|
||||
.def_readwrite("seq_slot", &LlmRequest::mSeqSlot)
|
||||
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
|
||||
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
||||
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
|
||||
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
||||
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
|
||||
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
|
||||
.def_property(
|
||||
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
|
||||
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
|
||||
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); })
|
||||
.def_property(
|
||||
"draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); },
|
||||
[](LlmRequest& self, LlmRequest::TensorPtr& logits)
|
||||
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); });
|
||||
}
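A minimal usage sketch for the `LlmRequest` binding defined above, using the `beam_width`-only `SamplingConfig` constructor bound elsewhere in this module; the import path and token values are assumptions.

```python
# Sketch only: module path and token values are assumptions.
import tensorrt_llm.bindings as trtllm

sampling = trtllm.SamplingConfig(beam_width=1)
req = trtllm.LlmRequest(
    request_id=0,
    max_new_tokens=16,
    input_tokens=[1, 2, 3],
    sampling_config=sampling,
    is_streaming=False,
)
req.add_new_token(42, beam=0)          # append a generated token to beam 0
print(req.get_tokens(0), req.max_beam_num_tokens)
```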

@@ -18,13 +18,12 @@
#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::batch_manager
{
@@ -58,6 +57,7 @@ public:
}

[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};

} // namespace tensorrt_llm::pybind::batch_manager

@@ -16,9 +16,14 @@
*/
#include "namedTensor.h"

#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/runtime/torch.h"

#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;

namespace tensorrt_llm::pybind::batch_manager
@@ -29,4 +34,12 @@ NamedTensor::NamedTensor(const tb::NamedTensor& cppNamedTensor)
{
}

void NamedTensor::initBindings(py::module_& m)
{
py::class_<NamedTensor>(m, "NamedTensor")
.def(py::init<NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
.def_readwrite("tensor", &NamedTensor::tensor)
.def_readonly("name", &NamedTensor::name);
}

} // namespace tensorrt_llm::pybind::batch_manager

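A short sketch of the `NamedTensor` binding above, which pairs an optional torch tensor with a name; the import path is an assumption.

```python
# Sketch only: module path assumed.
import torch
import tensorrt_llm.bindings as trtllm

logits = torch.zeros((1, 4, 32000), dtype=torch.float32)
named = trtllm.NamedTensor(tensor=logits, name="output_logits")
assert named.name == "output_logits"   # bound read-only
named.tensor = logits.half()           # bound read-write
```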
@ -18,28 +18,19 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <ATen/ops/from_blob.h>
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
class NamedTensor : public tb::GenericNamedTensor<std::optional<at::Tensor>>
|
||||
class NamedTensor : public tensorrt_llm::batch_manager::GenericNamedTensor<std::optional<at::Tensor>>
|
||||
{
|
||||
public:
|
||||
using Base = tb::GenericNamedTensor<std::optional<at::Tensor>>;
|
||||
using Base = tensorrt_llm::batch_manager::GenericNamedTensor<std::optional<at::Tensor>>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
|
||||
NamedTensor(TensorPtr _tensor, std::string _name)
|
||||
@ -48,7 +39,8 @@ public:
|
||||
explicit NamedTensor(std::string _name)
|
||||
: Base(std::move(_name)){};
|
||||
|
||||
explicit NamedTensor(const tb::NamedTensor& cppNamedTensor);
|
||||
explicit NamedTensor(const tensorrt_llm::batch_manager::NamedTensor& cppNamedTensor);
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
|
||||
@ -58,45 +58,18 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Python bindings for C++ runtime";
|
||||
|
||||
py::class_<tpr::PromptTuningParams>(m, "PromptTuningParams")
|
||||
.def(py::init<tpr::PromptTuningParams::TensorPtr, tpr::PromptTuningParams::TensorPtr,
|
||||
tpr::PromptTuningParams::TensorPtr>(),
|
||||
py::arg("embedding_table") = py::none(), py::arg("tasks") = py::none(), py::arg("vocab_size") = py::none())
|
||||
.def_readwrite("embedding_table", &tpr::PromptTuningParams::embeddingTable)
|
||||
.def_readwrite("tasks", &tpr::PromptTuningParams::tasks)
|
||||
.def_readwrite("vocab_size", &tpr::PromptTuningParams::vocabSize)
|
||||
.def_readwrite("prompt_tuning_enabled", &tpr::PromptTuningParams::promptTuningEnabled);
|
||||
|
||||
py::class_<tpr::GenerationInput>(m, "GenerationInput")
|
||||
.def(py::init<SizeType, SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr, bool>(),
|
||||
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
|
||||
.def_readwrite("end_id", &tpr::GenerationInput::endId)
|
||||
.def_readwrite("pad_id", &tpr::GenerationInput::padId)
|
||||
.def_readwrite("ids", &tpr::GenerationInput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationInput::lengths)
|
||||
.def_readwrite("packed", &tpr::GenerationInput::packed)
|
||||
.def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBias)
|
||||
.def_readwrite("bad_words_list", &tpr::GenerationInput::badWordsList)
|
||||
.def_readwrite("stop_words_list", &tpr::GenerationInput::stopWordsList)
|
||||
.def_readwrite("max_new_tokens", &tpr::GenerationInput::maxNewTokens)
|
||||
.def_readwrite("prompt_tuning_params", &tpr::GenerationInput::promptTuningParams);
|
||||
|
||||
py::class_<tpr::GenerationOutput>(m, "GenerationOutput")
|
||||
.def(py::init<tpr::GenerationOutput::TensorPtr, tpr::GenerationOutput::TensorPtr>(), py::arg("ids"),
|
||||
py::arg("lengths"))
|
||||
.def_readwrite("ids", &tpr::GenerationOutput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationOutput::lengths)
|
||||
.def_readwrite("log_probs", &tpr::GenerationOutput::logProbs)
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits)
|
||||
.def_readwrite("on_token_generated", &tpr::GenerationOutput::onTokenGenerated);
|
||||
tpr::PromptTuningParams::initBindings(m);
|
||||
tpr::GenerationInput::initBindings(m);
|
||||
tpr::GenerationOutput::initBindings(m);

py::class_<tbk::KvCacheConfig>(m, "KvCacheConfig")
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>>(),
py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none())
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>, bool>(),
py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none(), py::arg("enable_block_reuse") = false)
.def_readwrite("max_tokens", &tbk::KvCacheConfig::maxTokens)
.def_readwrite("max_kv_cache_length", &tbk::KvCacheConfig::maxKvCacheLength)
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction);
.def_readwrite("max_attention_window", &tbk::KvCacheConfig::maxAttentionWindow)
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction)
.def_readwrite("enable_block_reuse", &tbk::KvCacheConfig::enableBlockReuse);

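The hunk above renames `max_kv_cache_length` to `max_attention_window` and adds an `enable_block_reuse` flag to the `KvCacheConfig` binding. A small sketch of the updated constructor (import path assumed, values illustrative):

```python
# Sketch only: module path assumed; values are illustrative.
import tensorrt_llm.bindings as trtllm

kv_cache_config = trtllm.KvCacheConfig(
    max_tokens=8192,                # cap on tokens held in the paged KV cache
    max_attention_window=2048,      # per-sequence attention window (renamed field)
    free_gpu_memory_fraction=0.9,
    enable_block_reuse=False,       # new optional argument
)
```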
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
|
||||
.def(py::init<SizeType, SizeType, SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
|
||||
@ -184,9 +157,13 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_property("quant_mode", &tr::GptModelConfig::getQuantMode, &tr::GptModelConfig::setQuantMode)
|
||||
.def_property_readonly("supports_inflight_batching", &tr::GptModelConfig::supportsInflightBatching)
|
||||
.def_property("max_batch_size", &tr::GptModelConfig::getMaxBatchSize, &tr::GptModelConfig::setMaxBatchSize)
|
||||
.def_property("max_beam_width", &tr::GptModelConfig::getMaxBeamWidth, &tr::GptModelConfig::setMaxBeamWidth)
|
||||
.def_property("max_input_len", &tr::GptModelConfig::getMaxInputLen, &tr::GptModelConfig::setMaxInputLen)
|
||||
.def_property("max_output_len", &tr::GptModelConfig::getMaxOutputLen, &tr::GptModelConfig::setMaxOutputLen)
|
||||
.def_property("max_num_tokens", &tr::GptModelConfig::getMaxNumTokens, &tr::GptModelConfig::setMaxNumTokens)
|
||||
.def_property("max_prompt_embedding_table_size", &tr::GptModelConfig::getMaxPromptEmbeddingTableSize,
|
||||
&tr::GptModelConfig::setMaxPromptEmbeddingTableSize)
|
||||
.def_property_readonly("use_prompt_tuning", &tr::GptModelConfig::usePromptTuning)
|
||||
.def_property("compute_context_logits",
|
||||
py::overload_cast<>(&tr::GptModelConfig::computeContextLogits, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::computeContextLogits))
|
||||
@ -212,9 +189,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
|
||||
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
|
||||
.def_static("mpi",
|
||||
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>>(&tr::WorldConfig::mpi),
|
||||
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>,
|
||||
std::optional<std::vector<SizeType>>>(&tr::WorldConfig::mpi),
|
||||
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
|
||||
py::arg("pipeline_parallelism") = py::none());
|
||||
py::arg("pipeline_parallelism") = py::none(), py::arg("user_specified_device_ids") = py::none());
|
||||
|
||||
py::class_<tr::SamplingConfig>(m, "SamplingConfig")
|
||||
.def(py::init<SizeType>(), py::arg("beam_width") = 1)
|
||||
@ -233,14 +211,15 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
|
||||
|
||||
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
|
||||
.def(py::init<std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
|
||||
.def(py::init<std::string, std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
py::arg("version"), py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
|
||||
py::arg("model_config"))
|
||||
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
|
||||
.def_static(
|
||||
"parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
|
||||
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
|
||||
.def_property_readonly("name", &tr::GptJsonConfig::getName)
|
||||
.def_property_readonly("version", &tr::GptJsonConfig::getVersion)
|
||||
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
|
||||
.def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
|
||||
.def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
|
||||
@ -254,6 +233,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
py::arg("world_config"));
|
||||
|
||||
py::class_<tr::GptSession>(m, "GptSession")
|
||||
.def(py::init(
|
||||
[](tr::GptSession::Config const& config, tr::GptModelConfig const& modelConfig,
|
||||
tr::WorldConfig const& worldConfig, py::bytearray const& bytes)
|
||||
{
|
||||
auto buf = static_cast<std::string>(bytes);
|
||||
return tr::GptSession{config, modelConfig, worldConfig, buf.data(), buf.size()};
|
||||
}),
|
||||
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("engine_buffer"))
|
||||
.def(py::init<tr::GptSession::Config, tr::GptModelConfig, tr::WorldConfig, std::string>(), py::arg("config"),
|
||||
py::arg("model_config"), py::arg("world_config"), py::arg("engine_file"))
|
||||
.def_property_readonly("model_config", &tr::GptSession::getModelConfig)
|
||||
@ -272,65 +259,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.value("REQUEST_STATE_GENERATION_IN_PROGRESS", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_IN_PROGRESS)
|
||||
.value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE);
|
||||
|
||||
using LlmRequest = tpb::LlmRequest;
|
||||
py::class_<LlmRequest>(m, "LlmRequest")
|
||||
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, std::vector<LlmRequest::TokenIdType>,
|
||||
tr::SamplingConfig, bool, std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>,
|
||||
std::optional<LlmRequest::TensorPtr>>(),
|
||||
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
|
||||
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
|
||||
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
|
||||
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
|
||||
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
|
||||
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt)
|
||||
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
|
||||
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
|
||||
.def("get_tokens", &LlmRequest::getTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
|
||||
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
|
||||
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
|
||||
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
|
||||
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
|
||||
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
|
||||
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
|
||||
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
|
||||
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
|
||||
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
|
||||
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
|
||||
.def_readwrite("request_id", &LlmRequest::mRequestId)
|
||||
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
|
||||
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
|
||||
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
|
||||
.def_readwrite("state", &LlmRequest::mState)
|
||||
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
|
||||
.def_readwrite("end_id", &LlmRequest::mEndId)
|
||||
.def_readwrite("pad_id", &LlmRequest::mPadId)
|
||||
.def_readwrite("batch_slot", &LlmRequest::mBatchSlot)
|
||||
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
|
||||
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
||||
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
|
||||
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
||||
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
|
||||
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
|
||||
.def_property(
|
||||
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
|
||||
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
|
||||
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); })
|
||||
.def_property(
|
||||
"draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); },
|
||||
[](LlmRequest& self, LlmRequest::TensorPtr& logits)
|
||||
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); });
|
||||
|
||||
py::class_<tpb::NamedTensor>(m, "NamedTensor")
|
||||
.def(py::init<tpb::NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
|
||||
.def_readwrite("tensor", &tpb::NamedTensor::tensor)
|
||||
.def_readonly("name", &tpb::NamedTensor::name);
|
||||
tpb::NamedTensor::initBindings(m);
|
||||
tpb::LlmRequest::initBindings(m);
|
||||
|
||||
auto tensorNames = m.def_submodule("tensor_names");
|
||||
// Input tensor names
|
||||
@ -362,42 +292,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
tensorNames.attr("OUTPUT_LOG_PROBS") = py::str(tb::inference_request::kLogProbsTensorName);
|
||||
tensorNames.attr("CUM_LOG_PROBS") = py::str(tb::inference_request::kCumLogProbsTensorName);
|
||||
|
||||
using InferenceRequest = tpb::InferenceRequest;
|
||||
py::class_<InferenceRequest>(m, "InferenceRequest")
|
||||
.def(py::init<uint64_t>())
|
||||
.def(py::init<uint64_t, InferenceRequest::TensorMap const&>(), "deprecated: use direct tensor access instead")
|
||||
.def_property("input_ids", &InferenceRequest::getInputIdsUnchecked, &InferenceRequest::setInputIds)
|
||||
.def_property(
|
||||
"draft_input_ids", &InferenceRequest::getDraftInputIdsUnchecked, &InferenceRequest::setDraftInputIds)
|
||||
.def_property("draft_logits", &InferenceRequest::getDraftLogitsUnchecked, &InferenceRequest::setDraftLogits)
|
||||
.def_property("max_new_tokens", &InferenceRequest::getMaxNewTokensUnchecked, &InferenceRequest::setMaxNewTokens)
|
||||
.def_property("beam_width", &InferenceRequest::getBeamWidthUnchecked, &InferenceRequest::setBeamWidth)
|
||||
.def_property("end_id", &InferenceRequest::getEndIdUnchecked, &InferenceRequest::setEndId)
|
||||
.def_property("pad_id", &InferenceRequest::getPadIdUnchecked, &InferenceRequest::setPadId)
|
||||
.def_property("bad_words_list", &InferenceRequest::getBadWordsListUnchecked, &InferenceRequest::setBadWordsList)
|
||||
.def_property(
|
||||
"stop_words_list", &InferenceRequest::getStopWordsListUnchecked, &InferenceRequest::setStopWordsList)
|
||||
.def_property(
|
||||
"embedding_bias", &InferenceRequest::getEmbeddingBiasUnchecked, &InferenceRequest::setEmbeddingBias)
|
||||
.def_property("temperature", &InferenceRequest::getTemperatureUnchecked, &InferenceRequest::setTemperature)
|
||||
.def_property("runtime_top_k", &InferenceRequest::getRuntimeTopKUnchecked, &InferenceRequest::setRuntimeTopK)
|
||||
.def_property("runtime_top_p", &InferenceRequest::getRuntimeTopPUnchecked, &InferenceRequest::setRuntimeTopP)
|
||||
.def_property(
|
||||
"length_penalty", &InferenceRequest::getLengthPenaltyUnchecked, &InferenceRequest::setLengthPenalty)
|
||||
.def_property("repetition_penalty", &InferenceRequest::getRepetitionPenaltyUnchecked,
|
||||
&InferenceRequest::setRepetitionPenalty)
|
||||
.def_property("min_length", &InferenceRequest::getMinLengthUnchecked, &InferenceRequest::setMinLength)
|
||||
.def_property(
|
||||
"presence_penalty", &InferenceRequest::getPresencePenaltyUnchecked, &InferenceRequest::setPresencePenalty)
|
||||
.def_property("random_seed", &InferenceRequest::getRandomSeedUnchecked, &InferenceRequest::setRandomSeed)
|
||||
.def_property(
|
||||
"return_log_probs", &InferenceRequest::getReturnLogProbsUnchecked, &InferenceRequest::setReturnLogProbs)
|
||||
.def_property("prompt_embedding_table", &InferenceRequest::getPromptEmbeddingTableUnchecked,
|
||||
&InferenceRequest::setPromptEmbeddingTable)
|
||||
.def_property(
|
||||
"prompt_vocab_size", &InferenceRequest::getPromptVocabSizeUnchecked, &InferenceRequest::setPromptVocabSize)
|
||||
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
|
||||
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
|
||||
tpb::InferenceRequest::initBindings(m);
|
||||
|
||||
py::enum_<tb::TrtGptModelType>(m, "TrtGptModelType")
|
||||
.value("V1", tb::TrtGptModelType::V1)
|
||||
@@ -409,22 +304,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.value("GUARANTEED_NO_EVICT", tbb::SchedulerPolicy::GUARANTEED_NO_EVICT);

py::class_<tb::TrtGptModelOptionalParams>(m, "TrtGptModelOptionalParams")
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool>(),
py::arg("kv_cache_config") = tbk::KvCacheConfig{}, py::arg("max_num_sequences") = py::none(),
py::arg("enable_trt_overlap") = true)
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool, bool>(),
py::arg_v("kv_cache_config", tbk::KvCacheConfig{}, "KvCacheConfig()"),
py::arg("max_num_sequences") = py::none(), py::arg("enable_trt_overlap") = true,
py::arg("use_context_fmha_for_generation") = false)
.def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig)
.def_readwrite("max_num_sequences", &tb::TrtGptModelOptionalParams::maxNumSequences)
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap);
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap)
.def_readwrite("use_context_fmha_for_generation", &tb::TrtGptModelOptionalParams::useContextFMHAForGeneration);

py::class_<tpb::GptManager>(m, "GptManager")
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
tpb::GetInferenceRequestsCallback, tpb::SendResponseCallback, tb::PollStopSignalCallback,
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
py::arg("return_batch_manager_stats_cb") = nullptr,
py::arg("optional_params") = tb::TrtGptModelOptionalParams(), py::arg("terminate_req_id") = std::nullopt)
.def("shutdown", &tpb::GptManager::exit)
.def("__enter__", &tpb::GptManager::enter)
.def("__exit__", &tpb::GptManager::exit);
tpb::GptManager::initBindings(m);
}

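A sketch of wiring up `GptManager` through the constructor arguments bound above. The callback signatures are assumptions inferred from the callback type names, the import path and engine path are placeholders, and `shutdown()` mirrors the inline binding shown here, which `GptManager::initBindings` is assumed to preserve.

```python
# Sketch only: module path, engine path and callback signatures are assumptions.
import tensorrt_llm.bindings as trtllm

def get_inference_requests(max_num_requests):
    return []                       # supply a list of InferenceRequest objects

def send_response(request_id, output_tensors, is_final, error_msg):
    pass                            # consume NamedTensor outputs for a request

manager = trtllm.GptManager(
    trt_engine_path="/path/to/engine_dir",           # placeholder
    model_type=trtllm.TrtGptModelType.V1,
    max_beam_width=1,
    scheduler_policy=trtllm.SchedulerPolicy.GUARANTEED_NO_EVICT,
    get_inference_requests_cb=get_inference_requests,
    send_response_cb=send_response,
    optional_params=trtllm.TrtGptModelOptionalParams(),
)
manager.shutdown()
```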
@ -19,6 +19,11 @@
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
using namespace tensorrt_llm::pybind::runtime;
|
||||
@ -36,6 +41,17 @@ std::shared_ptr<tr::PromptTuningParams> PromptTuningParams::toTrtLlm() const
|
||||
return ptt;
|
||||
}
|
||||
|
||||
void PromptTuningParams::initBindings(pybind11::module_& m)
|
||||
{
|
||||
py::class_<PromptTuningParams>(m, "PromptTuningParams")
|
||||
.def(py::init<PromptTuningParams::TensorPtr, PromptTuningParams::TensorPtr, PromptTuningParams::TensorPtr>(),
|
||||
py::arg("embedding_table") = py::none(), py::arg("tasks") = py::none(), py::arg("vocab_size") = py::none())
|
||||
.def_readwrite("embedding_table", &PromptTuningParams::embeddingTable)
|
||||
.def_readwrite("tasks", &PromptTuningParams::tasks)
|
||||
.def_readwrite("vocab_size", &PromptTuningParams::vocabSize)
|
||||
.def_readwrite("prompt_tuning_enabled", &PromptTuningParams::promptTuningEnabled);
|
||||
}
|
||||
|
||||
std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
|
||||
{
|
||||
auto input = std::make_shared<tr::GenerationInput>(
|
||||
@ -52,3 +68,20 @@ std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
void GenerationInput::initBindings(pybind11::module_& m)
|
||||
{
|
||||
py::class_<GenerationInput>(m, "GenerationInput")
|
||||
.def(py::init<SizeType, SizeType, GenerationInput::TensorPtr, GenerationInput::TensorPtr, bool>(),
|
||||
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
|
||||
.def_readwrite("end_id", &GenerationInput::endId)
|
||||
.def_readwrite("pad_id", &GenerationInput::padId)
|
||||
.def_readwrite("ids", &GenerationInput::ids)
|
||||
.def_readwrite("lengths", &GenerationInput::lengths)
|
||||
.def_readwrite("packed", &GenerationInput::packed)
|
||||
.def_readwrite("embedding_bias", &GenerationInput::embeddingBias)
|
||||
.def_readwrite("bad_words_list", &GenerationInput::badWordsList)
|
||||
.def_readwrite("stop_words_list", &GenerationInput::stopWordsList)
|
||||
.def_readwrite("max_new_tokens", &GenerationInput::maxNewTokens)
|
||||
.def_readwrite("prompt_tuning_params", &GenerationInput::promptTuningParams);
|
||||
}
|
||||
|
||||
@@ -17,15 +17,14 @@

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/generationInput.h"

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::runtime
{
@@ -46,6 +45,7 @@ public:
}

[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::PromptTuningParams> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};

class GenerationInput
@@ -62,5 +62,6 @@ public:
}

[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationInput> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::runtime

@ -19,6 +19,11 @@
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
using namespace tensorrt_llm::pybind::runtime;
|
||||
@ -35,6 +40,10 @@ std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
|
||||
{
|
||||
output->contextLogits = tr::TorchView::of(contextLogits.value());
|
||||
}
|
||||
if (generationLogits)
|
||||
{
|
||||
output->generationLogits = tr::TorchView::of(generationLogits.value());
|
||||
}
|
||||
|
||||
if (onTokenGenerated)
|
||||
{
|
||||
@ -44,3 +53,15 @@ std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
void GenerationOutput::initBindings(py::module_& m)
|
||||
{
|
||||
py::class_<GenerationOutput>(m, "GenerationOutput")
|
||||
.def(py::init<GenerationOutput::TensorPtr, GenerationOutput::TensorPtr>(), py::arg("ids"), py::arg("lengths"))
|
||||
.def_readwrite("ids", &GenerationOutput::ids)
|
||||
.def_readwrite("lengths", &GenerationOutput::lengths)
|
||||
.def_readwrite("log_probs", &GenerationOutput::logProbs)
|
||||
.def_readwrite("context_logits", &GenerationOutput::contextLogits)
|
||||
.def_readwrite("generation_logits", &GenerationOutput::generationLogits)
|
||||
.def_readwrite("on_token_generated", &GenerationOutput::onTokenGenerated);
|
||||
}
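A short sketch that builds the `GenerationInput`/`GenerationOutput` pair bound above with torch tensors; the import path, dtypes, and shapes are assumptions.

```python
# Sketch only: module path, dtypes and shapes are assumptions.
import torch
import tensorrt_llm.bindings as trtllm

ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
lengths = torch.tensor([4], dtype=torch.int32)

gen_input = trtllm.GenerationInput(end_id=2, pad_id=0, ids=ids, lengths=lengths, packed=False)
gen_input.max_new_tokens = 32          # optional fields are plain read-write attributes

gen_output = trtllm.GenerationOutput(
    ids=torch.zeros((1, 1, 36), dtype=torch.int32),
    lengths=torch.zeros((1, 1), dtype=torch.int32),
)
```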

@@ -20,6 +20,7 @@

#include <ATen/ATen.h>
#include <optional>
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::runtime
{
@@ -36,6 +37,7 @@ public:
}

[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationOutput> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};

} // namespace tensorrt_llm::pybind::runtime

@ -100,7 +100,7 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co
|
||||
|
||||
auto constexpr ite = 0; // no pipeline parallelism
|
||||
typename tl::DynamicDecodeLayer<T>::ForwardParams forwardParams{input.step, ite, input.maxLength,
|
||||
input.maxKvCacheLength, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
|
||||
input.maxAttentionWindow, input.batchSize, tcc::toTllmTensor(*input.logits), tcc::toTllmTensor(*input.endIds)};
|
||||
|
||||
if (input.cacheIndirection)
|
||||
{
|
||||
|
||||
@ -114,7 +114,7 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
@ -125,7 +125,7 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
|
||||
mActualBatchSize = maxBatchSize;
|
||||
mGeneratedTokensPerStep.resize(maxBatchSize);
|
||||
mMaxSequenceLength = maxSequenceLength;
|
||||
mMaxKvCacheLength = maxKvCacheLength;
|
||||
mMaxAttentionWindow = maxAttentionWindow;
|
||||
mMaxTokensPerStep = maxTokensPerStep;
|
||||
|
||||
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
|
||||
@ -248,7 +248,7 @@ void GptDecoderBatch::newRequest(
|
||||
TensorPtr endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchIdx, localBatchSize)};
|
||||
kernels::invokeFill(*endIdTensorPtr, endId, *stream);
|
||||
dInput = std::make_unique<DecodingInput>(
|
||||
inputLength, mMaxKvCacheLength, localBatchSize, dJointInput.logits, endIdTensorPtr);
|
||||
inputLength, mMaxAttentionWindow, localBatchSize, dJointInput.logits, endIdTensorPtr);
|
||||
|
||||
// Here, we need to add leading 1 dimension since decoderInput expects batchSize as leading dim
|
||||
// and decoder_batch::Request doesn't have batch dimension
|
||||
|
||||
@ -74,75 +74,190 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
auto constexpr ingoreComments = true;
|
||||
auto json = nlohmann::json::parse(i, nullptr, allowExceptions, ingoreComments);
|
||||
|
||||
auto const& builderConfig = json.at("builder_config");
|
||||
auto const name = builderConfig.at("name").template get<std::string>();
|
||||
auto const precision = builderConfig.at("precision").template get<std::string>();
|
||||
auto const tensorParallelism = builderConfig.at("tensor_parallel").template get<SizeType>();
|
||||
auto const pipelineParallelism = parseJsonFieldOr(builderConfig, "pipeline_parallel", 1);
|
||||
auto const numHeads = builderConfig.at("num_heads").template get<SizeType>() / tensorParallelism;
|
||||
auto const hiddenSize = builderConfig.at("hidden_size").template get<SizeType>() / tensorParallelism;
|
||||
auto const vocabSize = builderConfig.at("vocab_size").template get<SizeType>();
|
||||
auto const numLayers = builderConfig.at("num_layers").template get<SizeType>();
|
||||
|
||||
auto dataType = nvinfer1::DataType::kFLOAT;
|
||||
if (!precision.compare("float32"))
|
||||
dataType = nvinfer1::DataType::kFLOAT;
|
||||
else if (!precision.compare("float16"))
|
||||
dataType = nvinfer1::DataType::kHALF;
|
||||
else if (!precision.compare("bfloat16"))
|
||||
dataType = nvinfer1::DataType::kBF16;
|
||||
else
|
||||
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
|
||||
|
||||
auto const quantMode = tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
|
||||
// TODO:
|
||||
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
|
||||
auto const numKvHeads = std::max(
|
||||
parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism, 1);
|
||||
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
|
||||
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
|
||||
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
|
||||
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
|
||||
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
|
||||
auto const maxPromptEmbeddingTableSize
|
||||
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
|
||||
|
||||
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
|
||||
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
|
||||
|
||||
auto const& pluginConfig = json.at("plugin_config");
|
||||
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
|
||||
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
|
||||
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
|
||||
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
|
||||
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
|
||||
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
|
||||
|
||||
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
|
||||
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
|
||||
modelConfig.usePackedInput(removeInputPadding);
|
||||
modelConfig.usePagedKvCache(pagedKvCache);
|
||||
modelConfig.useCustomAllReduce(useCustomAllReduce);
|
||||
modelConfig.setTokensPerBlock(tokensPerBlock);
|
||||
modelConfig.setQuantMode(quantMode);
|
||||
modelConfig.setNbKvHeads(numKvHeads);
|
||||
modelConfig.computeContextLogits(computeContextLogits);
|
||||
modelConfig.computeGenerationLogits(computeGenerationLogits);
|
||||
|
||||
modelConfig.setMaxBatchSize(maxBatchSize);
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
modelConfig.setMaxOutputLen(maxOutputLen);
|
||||
modelConfig.setMaxNumTokens(maxNumTokens);
|
||||
modelConfig.setMaxDraftLen(maxDraftLen);
|
||||
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
|
||||
|
||||
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
|
||||
auto engine_version = parseJsonFieldOr(json, "version", std::string("none"));
|
||||
if (engine_version == std::string("none"))
|
||||
{
|
||||
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
|
||||
// kGlm is only for ChatGLM-6B and GLM-10B
|
||||
}
|
||||
auto const& builderConfig = json.at("builder_config");
|
||||
auto const name = builderConfig.at("name").template get<std::string>();
|
||||
auto const precision = builderConfig.at("precision").template get<std::string>();
|
||||
auto const tensorParallelism = builderConfig.at("tensor_parallel").template get<SizeType>();
|
||||
auto const pipelineParallelism = parseJsonFieldOr(builderConfig, "pipeline_parallel", 1);
|
||||
auto const numHeads = builderConfig.at("num_heads").template get<SizeType>() / tensorParallelism;
|
||||
auto const hiddenSize = builderConfig.at("hidden_size").template get<SizeType>() / tensorParallelism;
|
||||
auto const vocabSize = builderConfig.at("vocab_size").template get<SizeType>();
|
||||
auto const numLayers = builderConfig.at("num_layers").template get<SizeType>();
|
||||
|
||||
return GptJsonConfig{name, precision, tensorParallelism, pipelineParallelism, modelConfig};
|
||||
auto dataType = nvinfer1::DataType::kFLOAT;
|
||||
if (!precision.compare("float32"))
|
||||
dataType = nvinfer1::DataType::kFLOAT;
|
||||
else if (!precision.compare("float16"))
|
||||
dataType = nvinfer1::DataType::kHALF;
|
||||
else if (!precision.compare("bfloat16"))
|
||||
dataType = nvinfer1::DataType::kBF16;
|
||||
else
|
||||
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", precision.c_str()));
|
||||
|
||||
auto const quantMode
|
||||
= tc::QuantMode(parseJsonFieldOr(builderConfig, "quant_mode", tc::QuantMode::none().value()));
|
||||
// TODO:
|
||||
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
|
||||
auto const numKvHeads = std::max(
|
||||
parseJsonFieldOr(builderConfig, "num_kv_heads", numHeads * tensorParallelism) / tensorParallelism, 1);
|
||||
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
|
||||
auto const maxBeamWidth = parseJsonFieldOr(builderConfig, "max_beam_width", 0);
|
||||
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
|
||||
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
|
||||
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
|
||||
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
|
||||
auto const maxPromptEmbeddingTableSize
|
||||
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
|
||||
|
||||
auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
|
||||
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
|
||||
|
||||
auto const& pluginConfig = json.at("plugin_config");
|
||||
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
|
||||
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
|
||||
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
|
||||
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
|
||||
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
|
||||
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
|
||||
|
||||
auto modelConfig = GptModelConfig{vocabSize, numLayers, numHeads, hiddenSize, dataType};
|
||||
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
|
||||
modelConfig.usePackedInput(removeInputPadding);
|
||||
modelConfig.usePagedKvCache(pagedKvCache);
|
||||
modelConfig.useCustomAllReduce(useCustomAllReduce);
|
||||
modelConfig.setTokensPerBlock(tokensPerBlock);
|
||||
modelConfig.setQuantMode(quantMode);
|
||||
modelConfig.setNbKvHeads(numKvHeads);
|
||||
modelConfig.computeContextLogits(computeContextLogits);
|
||||
modelConfig.computeGenerationLogits(computeGenerationLogits);
|
||||
|
||||
modelConfig.setMaxBatchSize(maxBatchSize);
|
||||
modelConfig.setMaxBeamWidth(maxBeamWidth);
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
modelConfig.setMaxOutputLen(maxOutputLen);
|
||||
modelConfig.setMaxNumTokens(maxNumTokens);
|
||||
modelConfig.setMaxDraftLen(maxDraftLen);
|
||||
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
|
||||
|
||||
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
|
||||
{
|
||||
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
|
||||
// kGlm is only for ChatGLM-6B and GLM-10B
|
||||
}
|
||||
|
||||
return GptJsonConfig{name, engine_version, precision, tensorParallelism, pipelineParallelism, modelConfig};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const& pretrainedConfig = json.at("pretrained_config");
|
||||
auto const& buildConfig = json.at("build_config");
|
||||
auto const architecture = pretrainedConfig.at("architecture").template get<std::string>();
|
||||
auto const name = architecture;
|
||||
|
||||
auto const dtype = pretrainedConfig.at("dtype").template get<std::string>();
|
||||
auto const& mapping = pretrainedConfig.at("mapping");
|
||||
auto const tpSize = mapping.at("tp_size").template get<SizeType>();
|
||||
auto const ppSize = parseJsonFieldOr(mapping, "pp_size", 1);
|
||||
auto const numAttentionHeads = pretrainedConfig.at("num_attention_heads").template get<SizeType>() / tpSize;
|
||||
auto const hiddenSize = pretrainedConfig.at("hidden_size").template get<SizeType>() / tpSize;
|
||||
auto const vocabSize = pretrainedConfig.at("vocab_size").template get<SizeType>();
|
||||
auto const numHiddenLayers = pretrainedConfig.at("num_hidden_layers").template get<SizeType>();
|
||||
|
||||
auto dataType = nvinfer1::DataType::kFLOAT;
|
||||
if (!dtype.compare("float32"))
|
||||
dataType = nvinfer1::DataType::kFLOAT;
|
||||
else if (!dtype.compare("float16"))
|
||||
dataType = nvinfer1::DataType::kHALF;
|
||||
else if (!dtype.compare("bfloat16"))
|
||||
dataType = nvinfer1::DataType::kBF16;
|
||||
else
|
||||
TLLM_CHECK_WITH_INFO(false, tc::fmtstr("Model data type '%s' not supported", dtype.c_str()));
|
||||
|
||||
auto const& quantization = pretrainedConfig.at("quantization");
|
||||
auto useSmoothQuant = parseJsonFieldOr(quantization, "use_smooth_quant", false);
|
||||
auto perChannel = parseJsonFieldOr(quantization, "per_channel", false);
|
||||
auto perToken = parseJsonFieldOr(quantization, "per_token", false);
|
||||
// TODO: Unused parameters
|
||||
// auto perGroup = parseJsonFieldOr(quantization, "per_group", false);
|
||||
// auto groupSize = parseJsonFieldOr(quantization, "group_size", 128);
|
||||
auto int8KvCache = parseJsonFieldOr(quantization, "int8_kv_cache", false);
|
||||
auto enableFp8 = parseJsonFieldOr(quantization, "enable_fp8", false);
|
||||
auto fp8KvCache = parseJsonFieldOr(quantization, "fp8_kv_cache", false);
|
||||
auto useWeightOnly = parseJsonFieldOr(quantization, "use_weight_only", false);
|
||||
auto weightOnlyPrecision = parseJsonFieldOr(quantization, "weight_only_precision", std::string("int8"));
|
||||
bool quantizeWeights = false;
|
||||
bool quantizeActivations = false;
|
||||
if (useSmoothQuant)
|
||||
{
|
||||
quantizeWeights = true;
|
||||
quantizeActivations = true;
|
||||
}
|
||||
else if (useWeightOnly)
|
||||
{
|
||||
quantizeWeights = true;
|
||||
perToken = false;
|
||||
perChannel = false;
|
||||
}
|
||||
bool useInt4Weights = (weightOnlyPrecision == std::string("int4"));
|
||||
auto const quantMode = tc::QuantMode::fromDescription(quantizeWeights, quantizeActivations, perToken,
|
||||
perChannel, useInt4Weights, int8KvCache, fp8KvCache, enableFp8);
|
||||
|
||||
// TODO:
|
||||
// Code crashes when numKvHeads <= 0. Clamping downwards to 1 prevents that, make sure this is best fix.
|
||||
auto const numKVHeads = pretrainedConfig.at("num_key_value_heads").template get<SizeType>();
|
||||
auto const numKeyValueHeads = std::max(numKVHeads / tpSize, 1);
|
||||
|
||||
auto const maxBatchSize = parseJsonFieldOr(buildConfig, "max_batch_size", 0);
|
||||
auto const maxBeamWidth = parseJsonFieldOr(buildConfig, "max_beam_width", 0);
|
||||
auto const maxInputLen = parseJsonFieldOr(buildConfig, "max_input_len", 0);
|
||||
auto const maxOutputLen = parseJsonFieldOr(buildConfig, "max_output_len", 0);
|
||||
auto const maxDraftLen = parseJsonFieldOr(buildConfig, "max_draft_len", 0);
|
||||
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(buildConfig, "max_num_tokens");
|
||||
auto const maxPromptEmbeddingTableSize
|
||||
= parseJsonFieldOr<SizeType>(buildConfig, "max_prompt_embedding_table_size", 0);
|
||||
|
||||
auto const computeContextLogits = parseJsonFieldOr(buildConfig, "gather_all_token_logits", false);
|
||||
auto const computeGenerationLogits = parseJsonFieldOr(buildConfig, "gather_all_token_logits", false);
|
||||
|
||||
auto const& pluginConfig = buildConfig.at("plugin_config");
|
||||
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
|
||||
auto const tokensPerBlock = pluginConfig.at("tokens_per_block");
|
||||
auto const& gptAttentionPlugin = pluginConfig.at("gpt_attention_plugin");
|
||||
auto const useGptAttentionPlugin = !gptAttentionPlugin.is_boolean() || gptAttentionPlugin.template get<bool>();
|
||||
auto const removeInputPadding = pluginConfig.at("remove_input_padding").template get<bool>();
|
||||
auto const useCustomAllReduce = pluginConfig.at("use_custom_all_reduce").template get<bool>();
|
||||
|
||||
auto modelConfig = GptModelConfig{vocabSize, numHiddenLayers, numAttentionHeads, hiddenSize, dataType};
|
||||
modelConfig.useGptAttentionPlugin(useGptAttentionPlugin);
|
||||
modelConfig.usePackedInput(removeInputPadding);
|
||||
modelConfig.usePagedKvCache(pagedKvCache);
|
||||
modelConfig.useCustomAllReduce(useCustomAllReduce);
|
||||
modelConfig.setTokensPerBlock(tokensPerBlock);
|
||||
modelConfig.setQuantMode(quantMode);
|
||||
modelConfig.setNbKvHeads(numKeyValueHeads);
|
||||
modelConfig.computeContextLogits(computeContextLogits);
|
||||
modelConfig.computeGenerationLogits(computeGenerationLogits);
|
||||
|
||||
modelConfig.setMaxBatchSize(maxBatchSize);
|
||||
modelConfig.setMaxBeamWidth(maxBeamWidth);
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
modelConfig.setMaxOutputLen(maxOutputLen);
|
||||
modelConfig.setMaxNumTokens(maxNumTokens);
|
||||
modelConfig.setMaxDraftLen(maxDraftLen);
|
||||
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
|
||||
|
||||
// TODO: Verify the architecture field in ChatGLM models
|
||||
if (name == std::string("ChatGLMModel") || name == std::string("GLMModel"))
|
||||
{
|
||||
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
|
||||
// kGlm is only for ChatGLM-6B and GLM-10B
|
||||
}
|
||||
|
||||
return GptJsonConfig{name, engine_version, dtype, tpSize, ppSize, modelConfig};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -153,8 +268,15 @@ std::string GptJsonConfig::engineFilename(WorldConfig const& worldConfig, std::s
TLLM_CHECK_WITH_INFO(
getPipelineParallelism() == worldConfig.getPipelineParallelism(), "pipeline parallelism mismatch");
auto pp = worldConfig.isPipelineParallel() ? "_pp" + std::to_string(worldConfig.getPipelineParallelism()) : "";
return model + "_" + getPrecision() + "_tp" + std::to_string(worldConfig.getTensorParallelism()) + pp + "_rank"
+ std::to_string(worldConfig.getRank()) + ".engine";
if (getVersion() == std::string("none"))
{
return model + "_" + getPrecision() + "_tp" + std::to_string(worldConfig.getTensorParallelism()) + pp + "_rank"
+ std::to_string(worldConfig.getRank()) + ".engine";
}
else
{
return "rank" + std::to_string(worldConfig.getRank()) + ".engine";
}
}

GptJsonConfig GptJsonConfig::parse(std::string const& json)

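To make the new version-aware naming concrete, a sketch that loads an engine config through the Python binding and mirrors the filename logic above; the import path and file locations are assumptions, and the pipeline-parallel suffix and `model` argument handling are simplified.

```python
# Sketch only: module path and file locations are assumptions.
import tensorrt_llm.bindings as trtllm

cfg = trtllm.GptJsonConfig.parse_file("/path/to/engine_dir/config.json")
rank = 0  # placeholder rank

if cfg.version == "none":   # legacy engines keep the long, descriptive name
    engine_name = f"{cfg.name}_{cfg.precision}_tp{cfg.tensor_parallelism}_rank{rank}.engine"
else:                       # new-style engines are simply rank<N>.engine
    engine_name = f"rank{rank}.engine"
```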
@ -117,7 +117,7 @@ void GptSession::createBuffers(SizeType numMicroBatches)
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
|
||||
void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
@ -135,13 +135,13 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
|
||||
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
|
||||
constexpr SizeType maxTokensPerStep = 1;
|
||||
mDecoders.back()->setup(
|
||||
batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, maxTokensPerStep, logitsType);
|
||||
batchSize, beamWidth, maxAttentionWindow, maxSequenceLength, maxTokensPerStep, logitsType);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength,
|
||||
void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, KvCacheConfig const& config)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
@ -169,11 +169,11 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
|
||||
= bmkv::KVCacheManager::getMaxNumTokens(config, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
|
||||
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
|
||||
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
|
||||
auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock);
|
||||
auto const maxBlocksPerSeq = tc::ceilDiv(std::min(maxSequenceLength, maxAttentionWindow), tokensPerBlock);
|
||||
|
||||
mKvCacheManager
|
||||
= std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbHeads, nbKvHeads, hiddenSize, tokensPerBlock,
|
||||
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxKvCacheLength, kvDtype, mRuntime->getStreamPtr());
|
||||
maxNumBlocks, batchSize, beamWidth, maxBlocksPerSeq, maxAttentionWindow, kvDtype, mRuntime->getStreamPtr());
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -234,8 +234,8 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
auto const maxBatchSize = sessionConfig.maxBatchSize;
|
||||
auto const maxBeamWidth = sessionConfig.maxBeamWidth;
|
||||
auto const maxSequenceLength = sessionConfig.maxSequenceLength;
|
||||
auto const maxKvCacheLength = sessionConfig.kvCacheConfig.maxKvCacheLength.has_value()
|
||||
? std::min(sessionConfig.kvCacheConfig.maxKvCacheLength.value(), maxSequenceLength)
|
||||
auto const maxAttentionWindow = sessionConfig.kvCacheConfig.maxAttentionWindow.has_value()
|
||||
? std::min(sessionConfig.kvCacheConfig.maxAttentionWindow.value(), maxSequenceLength)
|
||||
: maxSequenceLength;
|
||||
|
||||
mMicroBatchConfig = MicroBatchConfig(maxBatchSize, mWorldConfig.getPipelineParallelism(),
|
||||
@ -249,18 +249,18 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
// gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
|
||||
// TODO refactor batch manager to remove dependency on maxSequenceLength.
|
||||
mDecoderMaxSequenceLength = maxSequenceLength;
|
||||
mDecoderMaxKvCacheLength = maxKvCacheLength;
|
||||
mDecoderMaxAttentionWindow = maxAttentionWindow;
|
||||
|
||||
if (mModelConfig.usePagedKvCache())
|
||||
{
|
||||
createKvCacheManager(
|
||||
maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, sessionConfig.kvCacheConfig);
|
||||
maxBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength, sessionConfig.kvCacheConfig);
|
||||
}
|
||||
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
|
||||
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength, logitsType,
|
||||
createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength, logitsType,
|
||||
sessionConfig.decoderPerRequest, mMicroBatchConfig.numGenBatches);
|
||||
}
|
||||
|
||||
@ -280,7 +280,7 @@ void GptSession::setup(Config const& sessionConfig)
|
||||
{
|
||||
// we don't know maxInputLength yet and ignore it for pre-allocation
|
||||
buffers->generationConfig = RuntimeBuffers::GenerationConfig{
|
||||
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxKvCacheLength, maxSequenceLength};
|
||||
mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxAttentionWindow, maxSequenceLength};
|
||||
buffers->reshape(mModelConfig, mWorldConfig);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
@ -295,9 +295,9 @@ void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId,
|
||||
TLLM_CHECK(contextLengthsHost);
|
||||
auto const contextLengthsPtr = bufferCast<SizeType const>(*contextLengthsHost);
|
||||
auto const contextLengthsSize = static_cast<SizeType>(contextLengthsHost->getSize());
|
||||
for (SizeType batchIdx = firstBatchIdx; batchIdx < firstBatchIdx + contextLengthsSize; ++batchIdx)
|
||||
for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx)
|
||||
{
|
||||
mKvCacheManager->addSequence(batchIdx, contextLengthsPtr[batchIdx], beamWidth);
|
||||
mKvCacheManager->addSequence(firstBatchIdx + batchIdx, contextLengthsPtr[batchIdx], beamWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -621,7 +621,7 @@ void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutp
|
||||
auto const& microBatchInputs = microBatchesInputs.at(microBatchId);
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
|
||||
mDecoderMaxKvCacheLength, mDecoderMaxSequenceLength, manager);
|
||||
mDecoderMaxAttentionWindow, mDecoderMaxSequenceLength, manager);
|
||||
buffers.reshape(mModelConfig, mWorldConfig);
|
||||
buffers.reset(manager);
|
||||
}
|
||||
|
||||
@ -29,8 +29,8 @@ using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITensor const& inputIds,
|
||||
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxKvCacheLength,
|
||||
SizeType const maxSequenceLength)
|
||||
ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth,
|
||||
SizeType const maxAttentionWindow, SizeType const maxSequenceLength)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
|
||||
@ -58,7 +58,8 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
|
||||
"generated.");
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return GenerationConfig{batchSize, beamWidth, maxInputLength, maxKvCacheLength, maxSequenceLength, inputLengthSum};
|
||||
return GenerationConfig{
|
||||
batchSize, beamWidth, maxInputLength, maxAttentionWindow, maxSequenceLength, inputLengthSum};
|
||||
}
|
||||
|
||||
void RuntimeBuffers::clear()
|
||||
@ -158,7 +159,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
maxKvCacheLengths
|
||||
maxAttentionWindows
|
||||
= utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
else
|
||||
@ -185,7 +186,7 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
}
|
||||
|
||||
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
|
||||
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager)
|
||||
SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength, BufferManager& manager)
|
||||
{
|
||||
contextLengthsDevice = inputLengths;
|
||||
contextLengthsHost->reshape(inputLengths->getShape());
|
||||
@ -193,7 +194,7 @@ void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inp
|
||||
manager.getStream().synchronize(); // wait for context lengths to be copied to host
|
||||
|
||||
generationConfig = RuntimeBuffers::GenerationConfig::fromInput(
|
||||
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxKvCacheLength, maxSequenceLength);
|
||||
inputIds, *contextLengthsHost, inputPacked, beamWidth, maxAttentionWindow, maxSequenceLength);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
@ -203,7 +204,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
auto const batchSize = generationConfig.batchSize;
|
||||
auto const beamWidth = generationConfig.beamWidth;
|
||||
auto const maxInputLength = generationConfig.maxInputLength;
|
||||
auto const maxKvCacheLength = generationConfig.maxKvCacheLength;
|
||||
auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
|
||||
|
||||
if (worldConfig.isLastPipelineParallelRank() && !modelConfig.computeContextLogits())
|
||||
{
|
||||
@ -214,14 +215,14 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
lastTokenIds->reshape(ITensor::makeShape({batchSize}));
|
||||
|
||||
auto kvCacheReserve = ITensor::makeShape(
|
||||
{batchSize, 2, modelConfig.getNbKvHeads(), maxKvCacheLength, modelConfig.getSizePerHead()});
|
||||
{batchSize, 2, modelConfig.getNbKvHeads(), maxAttentionWindow, modelConfig.getSizePerHead()});
|
||||
auto kvCacheShape
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
|
||||
if (modelConfig.usePagedKvCache())
|
||||
{
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
|
||||
auto const maxBlocksPerSeq = (maxKvCacheLength + tokensPerBlock - 1) / tokensPerBlock;
|
||||
auto const maxBlocksPerSeq = (maxAttentionWindow + tokensPerBlock - 1) / tokensPerBlock;
|
||||
|
||||
// reserve batchSize * beamWidth and resize to batchSize
|
||||
auto cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize * beamWidth, 2, maxBlocksPerSeq});
|
||||
@ -240,7 +241,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
{
|
||||
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
|
||||
requestTypes->reshape(ITensor::makeShape({batchSize}));
|
||||
utils::reshapeBufferVector(maxKvCacheLengths, ITensor::makeShape({1}));
|
||||
utils::reshapeBufferVector(maxAttentionWindows, ITensor::makeShape({1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -250,7 +251,7 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
utils::reshapeBufferVector(presentKeysVals, kvCacheShape);
|
||||
}
|
||||
|
||||
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxKvCacheLength});
|
||||
auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxAttentionWindow});
|
||||
cacheIndirectionDecoderInput->reshape(cacheIndirShape);
|
||||
cacheIndirectionDecoderOutput->reshape(cacheIndirShape);
|
||||
|
||||
@ -334,7 +335,7 @@ std::vector<RuntimeBuffers> RuntimeBuffers::split(
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
buffers.pastKeyValueLengths = ITensor::slice(pastKeyValueLengths, offset, batchSize);
|
||||
buffers.maxKvCacheLengths = maxKvCacheLengths;
|
||||
buffers.maxAttentionWindows = maxAttentionWindows;
|
||||
buffers.requestTypes = ITensor::slice(requestTypes, offset, batchSize);
|
||||
}
|
||||
else
|
||||
@ -548,10 +549,10 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
|
||||
// Set maxKvCacheLengths buffer to the same value currently.
|
||||
// Set maxAttentionWindows buffer to the same value currently.
|
||||
for (auto layer = 0; layer < localNbLayers; ++layer)
|
||||
{
|
||||
bufferCast<SizeType>(*maxKvCacheLengths[layer])[0] = generationConfig.maxKvCacheLength;
|
||||
bufferCast<SizeType>(*maxAttentionWindows[layer])[0] = generationConfig.maxAttentionWindow;
|
||||
}
|
||||
|
||||
auto const& inputShape = inputIds->getShape();
|
||||
@ -823,7 +824,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
inputBuffers.insert_or_assign("host_past_key_value_lengths", pastKeyValueLengths);
|
||||
inputBuffers.insert_or_assign("host_request_types", requestTypes);
|
||||
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
|
||||
utils::insertTensorVector(inputBuffers, "host_max_kv_cache_length_", maxKvCacheLengths, firstLayerId);
|
||||
utils::insertTensorVector(inputBuffers, "host_max_attention_window_size_", maxAttentionWindows, firstLayerId);
|
||||
|
||||
if (modelConfig.usePackedInput())
|
||||
{
|
||||
|
||||
@ -49,11 +49,11 @@ public:
|
||||
GenerationConfig() = default;
|
||||
|
||||
explicit GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength,
|
||||
SizeType maxKvCacheLength, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
|
||||
SizeType maxAttentionWindow, SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
|
||||
: batchSize{batchSize}
|
||||
, beamWidth{beamWidth}
|
||||
, maxInputLength{maxInputLength}
|
||||
, maxKvCacheLength{maxKvCacheLength}
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, maxSeqLength{maxSeqLength}
|
||||
, inputLengthSum{inputLengthSum}
|
||||
{
|
||||
@ -62,12 +62,12 @@ public:
|
||||
SizeType batchSize{};
|
||||
SizeType beamWidth{};
|
||||
SizeType maxInputLength{};
|
||||
SizeType maxKvCacheLength{};
|
||||
SizeType maxAttentionWindow{};
|
||||
SizeType maxSeqLength{};
|
||||
SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput.
|
||||
|
||||
static GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths, bool inputPacked,
|
||||
SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength);
|
||||
SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength);
|
||||
};
|
||||
|
||||
public:
|
||||
@ -89,10 +89,10 @@ public:
|
||||
TensorPtr requestTypes; // with attention plugin. Host tensor
|
||||
|
||||
std::vector<TensorPtr> presentKeysVals;
|
||||
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
|
||||
std::vector<TensorPtr> maxKvCacheLengths; // with attention plugin, host tensor
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
|
||||
std::vector<TensorPtr> maxAttentionWindows; // with attention plugin, host tensor
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
|
||||
// References to tmp buffers
|
||||
TensorPtr newTokens;
|
||||
@ -126,7 +126,7 @@ public:
|
||||
void create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, SizeType beamWidth,
|
||||
SizeType maxKvCacheLength, SizeType maxSequenceLength, BufferManager& manager);
|
||||
SizeType maxAttentionWindow, SizeType maxSequenceLength, BufferManager& manager);
|
||||
|
||||
//! \brief Reshape buffers based on current GenerationConfig
|
||||
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
@ -908,6 +908,7 @@ void tileTensor(ITensor& output, ITensor const& input, SizeType beamWidth, CudaS
|
||||
case nvinfer1::DataType::kINT32: invokeTileTensor<SizeType>(output, input, beamWidth, stream); break;
|
||||
case nvinfer1::DataType::kFLOAT: invokeTileTensor<float>(output, input, beamWidth, stream); break;
|
||||
case nvinfer1::DataType::kHALF: invokeTileTensor<half>(output, input, beamWidth, stream); break;
|
||||
case nvinfer1::DataType::kBF16: invokeTileTensor<__nv_bfloat16>(output, input, beamWidth, stream); break;
|
||||
case nvinfer1::DataType::kINT8: invokeTileTensor<int8_t>(output, input, beamWidth, stream); break;
|
||||
case nvinfer1::DataType::kFP8: invokeTileTensor<__nv_fp8_e4m3>(output, input, beamWidth, stream); break;
|
||||
default: TLLM_CHECK_WITH_INFO(false, "data type not supported");
|
||||
|
||||
@ -64,19 +64,19 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxTokensPerStep == 1);
|
||||
mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream);
|
||||
|
||||
reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength);
|
||||
reshapeBuffers(maxBatchSize, maxBeamWidth, maxAttentionWindow, maxSequenceLength);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::reshapeBuffers(
|
||||
SizeType batchSize, SizeType beamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength)
|
||||
SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(batchSize > 0);
|
||||
@ -84,7 +84,7 @@ void StatefulGptDecoder::reshapeBuffers(
|
||||
TLLM_CHECK(maxSequenceLength > 0);
|
||||
|
||||
mMaxSequenceLength = maxSequenceLength;
|
||||
mMaxKvCacheLength = maxKvCacheLength;
|
||||
mMaxAttentionWindow = maxAttentionWindow;
|
||||
|
||||
auto const batchSizeShape = ITensor::makeShape({batchSize});
|
||||
auto const batchSizeXbeamWidth = ITensor::makeShape({batchSize, beamWidth});
|
||||
@ -137,7 +137,7 @@ void StatefulGptDecoder::newBatch(
|
||||
auto const batchSize = inputLengthsShape.d[0];
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
|
||||
reshapeBuffers(batchSize, beamWidth, mMaxKvCacheLength, mMaxSequenceLength);
|
||||
reshapeBuffers(batchSize, beamWidth, mMaxAttentionWindow, mMaxSequenceLength);
|
||||
mDecoder->setup(samplingConfig, batchSize, mMaxSequenceLength);
|
||||
|
||||
// sanity checks, should always be true after reshape
|
||||
@ -167,7 +167,7 @@ void StatefulGptDecoder::newBatch(
|
||||
// inputs
|
||||
auto& dInput = *mDecodingInput;
|
||||
dInput.maxLength = maxInputLength;
|
||||
dInput.maxKvCacheLength = mMaxKvCacheLength;
|
||||
dInput.maxAttentionWindow = mMaxAttentionWindow;
|
||||
dInput.batchSize = batchSize;
|
||||
kernels::invokeFill(const_cast<ITensor&>(*dInput.endIds), endId, *stream);
|
||||
dInput.embeddingBias = inputs.embeddingBias;
|
||||
|
||||
@ -39,7 +39,7 @@ public:
|
||||
StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
|
||||
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
|
||||
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
@ -98,7 +98,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void reshapeBuffers(SizeType batchSize, SizeType beamWidth, SizeType mMaxKvCacheLength, SizeType maxSequenceLength);
|
||||
void reshapeBuffers(
|
||||
SizeType batchSize, SizeType beamWidth, SizeType mMaxAttentionWindow, SizeType maxSequenceLength);
|
||||
|
||||
private:
|
||||
std::size_t const mVocabSize;
|
||||
@ -116,6 +117,6 @@ private:
|
||||
|
||||
SizeType mNbSteps;
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxKvCacheLength{};
|
||||
SizeType mMaxAttentionWindow{};
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -70,11 +70,6 @@ public:
|
||||
return mTensor.is_cuda() ? MemoryType::kGPU : mTensor.is_pinned() ? MemoryType::kPINNED : MemoryType::kCPU;
|
||||
}
|
||||
|
||||
void resize(std::size_t newSize) override
|
||||
{
|
||||
ITensor::resize(newSize);
|
||||
}
|
||||
|
||||
void release() override
|
||||
{
|
||||
resize(0);
|
||||
@ -87,9 +82,19 @@ public:
|
||||
|
||||
void reshape(Shape const& dims) override
|
||||
{
|
||||
TLLM_CHECK(volumeNonNegative(dims) <= getCapacity());
|
||||
mTensor.resize_(TorchUtils::shape(dims));
|
||||
try
|
||||
{
|
||||
mTensor.resize_(TorchUtils::shape(dims));
|
||||
}
|
||||
catch (c10::Error const& e)
|
||||
{
|
||||
TLLM_THROW("%s", e.what_without_backtrace());
|
||||
}
|
||||
mDims = dims;
|
||||
if (auto const newSize = volumeNonNegative(dims); mCapacity < newSize)
|
||||
{
|
||||
mCapacity = newSize;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@ -21,10 +21,12 @@
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"

#include <algorithm>
#include <csignal>
#include <cstdlib>
#include <mpi.h>
#include <mutex>
#include <numeric>

using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
@ -73,7 +75,7 @@ bool WorldConfig::validConfig(nvinfer1::ILogger& logger, SizeType tensorParallel
}

WorldConfig WorldConfig::mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode, std::optional<SizeType> tensorParallelism,
std::optional<SizeType> pipelineParallelism)
std::optional<SizeType> pipelineParallelism, std::optional<std::vector<SizeType>> userSpecifiedDeviceIds)
{
initMpi(logger);

@ -85,14 +87,60 @@ WorldConfig WorldConfig::mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode, st
auto pp = pipelineParallelism.value_or(1);
auto tp = tensorParallelism.value_or(mpiSize / pp);
TLLM_CHECK(mpiSize == tp * pp);
return WorldConfig{tp, pp, mpiRank, gpusPerNode};
// Pass the user-specified device lists to the WorldConfig. Otherwise create a default list of device ids.
std::vector<SizeType> deviceIds(mpiSize);
std::iota(deviceIds.begin(), deviceIds.end(), 0);

if (userSpecifiedDeviceIds)
{
TLLM_CHECK(static_cast<SizeType>(userSpecifiedDeviceIds.value().size()) == tp * pp);
deviceIds = userSpecifiedDeviceIds.value();

// For user provided device list, verify:
// 1) total number is smaller than the total cuda-visible device counts
// 2) all deviceIds is within the range
// 3) All ids are unique
// 4) if the deviceIds are contiguous, and throw a warning if not
TLLM_CHECK((gpusPerNode >= static_cast<SizeType>(deviceIds.size()))
&& (gpusPerNode > *std::max_element(deviceIds.begin(), deviceIds.end()))
&& *std::min_element(deviceIds.begin(), deviceIds.end()) >= 0);
gpusPerNode = deviceIds.size();
auto it = std::unique(deviceIds.begin(), deviceIds.end());
TLLM_CHECK(std::distance(deviceIds.begin(), it) == gpusPerNode);
std::sort(deviceIds.begin(), deviceIds.end());

// If the deviceIds are not contiguous, throw a warning
bool isContiguous = true;
for (SizeType i = 1; i < static_cast<SizeType>(deviceIds.size()); ++i)
{
if (deviceIds[i] != deviceIds[i - 1] + 1)
{
isContiguous = false;
break;
}
}
if (!isContiguous)
{
logger.log(nvinfer1::ILogger::Severity::kWARNING, "The user specified device IDs are not contiguous!");
}

std::stringstream ss;
ss << "Using user-specificed devices: [";
for (auto& id : deviceIds)
{
ss << id << ",";
}
ss << "]";
logger.log(nvinfer1::ILogger::Severity::kINFO, ss.str().c_str());
}
return WorldConfig{tp, pp, mpiRank, gpusPerNode, deviceIds};
}

WorldConfig WorldConfig::mpi(
SizeType gpusPerNode, std::optional<SizeType> tensorParallelism, std::optional<SizeType> pipelineParallelism)
WorldConfig WorldConfig::mpi(SizeType gpusPerNode, std::optional<SizeType> tensorParallelism,
std::optional<SizeType> pipelineParallelism, std::optional<std::vector<SizeType>> userSpecifiedDeviceIds)
{
TllmLogger logger{};
return mpi(logger, gpusPerNode, tensorParallelism, pipelineParallelism);
return mpi(logger, gpusPerNode, tensorParallelism, pipelineParallelism, userSpecifiedDeviceIds);
}

std::vector<SizeType> WorldConfig::getPipelineParallelGroup() const

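Note: the device-ID handling added above requires the user-specified list to fit on the node, stay within the CUDA-visible range, and contain unique entries, while a non-contiguous list only triggers a warning. A self-contained sketch of those checks (hypothetical helper name; simplified from the inline `TLLM_CHECK`s above):

```cpp
#include <algorithm>
#include <vector>

// Sketch only: validates a user-specified device list the way the checks above do.
// Contiguity is reported separately because the real code merely logs a warning for gaps.
bool checkDeviceIds(std::vector<int> ids, int gpusPerNode, bool& isContiguous)
{
    if (ids.empty() || static_cast<int>(ids.size()) > gpusPerNode)
        return false;
    std::sort(ids.begin(), ids.end());
    if (ids.front() < 0 || ids.back() >= gpusPerNode)
        return false;
    if (std::adjacent_find(ids.begin(), ids.end()) != ids.end()) // duplicate id
        return false;
    isContiguous = (ids.back() - ids.front() == static_cast<int>(ids.size()) - 1);
    return true;
}
```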
@ -138,7 +138,7 @@ void FtDynamicDecode<T>::setup(size_t batch_size, size_t beam_width, th::optiona
|
||||
|
||||
template <typename T>
|
||||
void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
|
||||
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
|
||||
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
@ -158,7 +158,7 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
|
||||
auto const& logits_converted = convert_tensor<float>(logits);
|
||||
auto const& end_ids_converted = convert_tensor<int>(end_id);
|
||||
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::ForwardParams forwardParams{step, static_cast<int>(ite),
|
||||
max_input_length, max_kv_cache_length, local_batch_size, logits_converted, end_ids_converted};
|
||||
max_input_length, max_attention_window, local_batch_size, logits_converted, end_ids_converted};
|
||||
|
||||
safeUpdate<int>(src_cache_indirection_opt, forwardParams.src_cache_indirection);
|
||||
safeUpdate<int>(sequence_limit_length_opt, forwardParams.sequence_limit_length);
|
||||
@ -275,7 +275,7 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional
|
||||
}
|
||||
|
||||
th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length,
|
||||
int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
|
||||
int64_t max_attention_window, int64_t ite, int64_t local_batch_size, th::Tensor end_id,
|
||||
th::optional<th::Tensor> embedding_bias_opt,
|
||||
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
|
||||
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
|
||||
@ -341,7 +341,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
|
||||
|
||||
dynamic_decode_->forward(
|
||||
// Inputs
|
||||
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_kv_cache_length),
|
||||
logits, static_cast<int>(step), static_cast<int>(max_input_length), static_cast<int>(max_attention_window),
|
||||
static_cast<uint32_t>(ite), static_cast<int>(local_batch_size), end_id, embedding_bias_opt, input_lengths_opt,
|
||||
sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt,
|
||||
src_cache_indirection_opt,
|
||||
|
||||
@ -39,7 +39,7 @@ public:
|
||||
= 0;
|
||||
|
||||
virtual void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
|
||||
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
|
||||
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
@ -77,7 +77,7 @@ public:
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt) override;
|
||||
|
||||
void forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size)
|
||||
int step, int max_input_length, int max_kv_cache_length, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
int step, int max_input_length, int max_attention_window, uint64_t ite, int local_batch_size, th::Tensor end_id,
|
||||
th::optional<th::Tensor> embedding_bias_opt, th::optional<th::Tensor> input_lengths_opt,
|
||||
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
@ -121,7 +121,7 @@ public:
|
||||
th::optional<th::Tensor> top_p_reset_ids_opt);
|
||||
|
||||
th::Tensor forward(th::Tensor logits, // (batch_size, beam_width, vocab_size)
|
||||
int64_t step, int64_t max_input_length, int64_t max_kv_cache_length, int64_t ite, int64_t local_batch_size,
|
||||
int64_t step, int64_t max_input_length, int64_t max_attention_window, int64_t ite, int64_t local_batch_size,
|
||||
th::Tensor end_id, th::optional<th::Tensor> embedding_bias_opt,
|
||||
th::optional<th::Tensor> input_lengths_opt, // length of input contexts.
|
||||
th::optional<th::Tensor> sequence_limit_length_opt, th::optional<th::Tensor> stop_words_list_opt,
|
||||
|
||||
@ -4,13 +4,33 @@ This document explains how to build and run the C++ tests, and the included [res

Windows users: Be sure to set DLL paths as specified in [Extra Steps for C++ Runtime Usage](../../windows/README.md#extra-steps-for-c-runtime-usage).

## Compile
## All-in-one script

The script [test_cpp.py](resources/scripts/test_cpp.py) can be executed to build TRT-LLM, build engines, generate expected outputs and run C++ tests all in one go.
To get an overview of the parameters call:

```bash
python3 cpp/tests/resources/scripts/test_cpp.py -h
```

It is possible to choose a single model for end-to-end tests or skip models that should not be tested.
An example call may look like this:

```bash
CPP_BUILD_DIR=cpp/build
MODEL_CACHE=/path/to/model_cache
python3 cpp/tests/resources/scripts/test_cpp.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR} --trt_root /usr/local/tensorrt --model_cache ${MODEL_CACHE} --only_gptj
```

## Manual steps

### Compile

From the top-level directory call:

```bash
CPP_BUILD_DIR=cpp/build
python3 scripts/build_wheel.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR}
python3 scripts/build_wheel.py -a "80-real;86-real" --build_dir ${CPP_BUILD_DIR} --trt_root /usr/local/tensorrt
pip install -r requirements-dev.txt --extra-index-url https://pypi.ngc.nvidia.com
pip install build/tensorrt_llm*.whl
cd $CPP_BUILD_DIR && make -j$(nproc) google-tests
@ -22,11 +42,11 @@ Single tests can be executed from `CPP_BUILD_DIR/tests`, e.g.
./$CPP_BUILD_DIR/tests/allocatorTest
```

## End-to-end tests
### End-to-end tests

`gptSessionTest`, `gptManagerTest` and `trtGptModelRealDecoderTest` require pre-built TensorRT engines, which are loaded in the tests. They also require data files which are stored in [cpp/tests/resources/data](resources/data).

### Build engines
#### Build engines

[Scripts](resources/scripts) are provided that download the GPT2 and GPT-J models from Huggingface and convert them to TensorRT engines.
The weights and built engines are stored under [cpp/tests/resources/models](resources/models).
@ -45,7 +65,7 @@ It is possible to build engines with tensor and pipeline parallelism for LLaMA u
PYTHONPATH=examples/llama python3 cpp/tests/resources/scripts/build_llama_engines.py --only_multi_gpu
```

### Generate expected output
#### Generate expected output

End-to-end tests read inputs and expected outputs from Numpy files located at [cpp/tests/resources/data](resources/data). The expected outputs can be generated using [scripts](resources/scripts) which employ the Python runtime to run the built engines:

@ -56,7 +76,7 @@ PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_exp
PYTHONPATH=examples:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm_output.py
```

### Generate data with tensor and pipeline parallelism
#### Generate data with tensor and pipeline parallelism

It is possible to generate tensor and pipeline parallelism data for LLaMA using 4 GPUs. To generate results from the top-level directory:

@ -64,7 +84,7 @@ It is possible to generate tensor and pipeline parallelism data for LLaMA using
PYTHONPATH=examples mpirun -n 4 python3 cpp/tests/resources/scripts/generate_expected_llama_output.py --only_multi_gpu
```

### Run test
#### Run test

After building the engines and generating the expected output execute the tests

@ -72,7 +92,7 @@ After building the engines and generating the expected output execute the tests
./$CPP_BUILD_DIR/tests/gptSessionTest
```

## Run all tests with ctest
### Run all tests with ctest

To run all tests and produce an xml report, call

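The hunk ends before the actual command appears; for reference, an invocation along these lines would produce a JUnit-style XML report (a sketch only: the `--output-junit` flag requires a recent CTest, and the report file name here is just an example):

```bash
cd $CPP_BUILD_DIR && ctest --output-on-failure --output-junit results.xml
```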
@ -292,8 +292,9 @@ float compare(void* _pa, void* _pb, int size, float scale)
#if defined(ENABLE_BF16)
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
// bfloat16 has fewer mantissa digits than float16, so the cumulative error will be larger.
diff_thres *= 2.f;
// bfloat16 has fewer mantissa digits than float16(10 bits for fp16 but only 7 bits for bf16), so the cumulative
// error will be larger.
diff_thres *= 3.f;
}
else
#endif
@ -308,8 +309,7 @@ float compare(void* _pa, void* _pb, int size, float scale)
template <typename T1, typename T2>
void random_fill(std::vector<T1>& vec, T2 minv, T2 maxv)
{
std::random_device rd;
std::mt19937 gen(rd());
std::mt19937 gen(20231205);
std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
for (auto& v : vec)
{

@ -98,7 +98,10 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
run_command(["git", "lfs", "pull", "--include", model_file_name],
cwd=hf_dir)

(hf_dir / "model.safetensors").unlink(missing_ok=True)
safetensor_file = hf_dir / "model.safetensors"
has_safetensor = safetensor_file.exists()
if has_safetensor:
safetensor_file.rename(str(safetensor_file) + ".bak")

assert (hf_dir / model_file_name).is_file()

@ -158,6 +161,9 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
engine_dir / 'fp16-plugin-packed-paged-gather' / tp_pp_dir,
tp_size, '--gather_all_token_logits', *ifb_args)

if has_safetensor:
_pl.Path(str(safetensor_file) + ".bak").rename(safetensor_file)

print("Done.")

@ -32,7 +32,6 @@ def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, *args):
'--dtype=float16',
'--logits_dtype=float16',
'--use_gemm_plugin=float16',
'--use_layernorm_plugin=float16',
'--max_batch_size=32',
'--max_input_len=40',
'--max_output_len=20',

@ -55,7 +55,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--output_logits_npy',
str(output_logits_npy)
str(output_logits_npy), '--use_py_session'
])
run.main(args)

@ -52,7 +52,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams)
str(num_beams), '--use_py_session'
])
run.main(args)

@ -52,7 +52,7 @@ def generate_output(engine: str,
str(output_dir / (output_name + '.npy')), '--output_csv',
str(output_dir / (output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams)
str(num_beams), '--use_py_session'
])
run.main(args)

@ -68,7 +68,7 @@ def build_trt_llm(python_exe: str,
python_exe, "scripts/build_wheel.py", "--cuda_architectures",
cuda_architectures, "--build_dir",
str(build_dir), "--dist_dir",
str(dist_dir)
str(dist_dir), "--python_bindings"
]
if trt_root is not None:
build_wheel += ["--trt_root", str(trt_root)]

Some files were not shown because too many files have changed in this diff.