mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#302)

* Update TensorRT-LLM

Co-authored-by: wangruohui <12756472+wangruohui@users.noreply.github.com>

This commit is contained in:
parent 4de32a86ae
commit f044eb8d94

README.md (19 lines changed)
@@ -8,7 +8,6 @@ TensorRT-LLM

[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./setup.py)
[](./LICENSE)

[Architecture](./docs/source/architecture.md) | [Results](./docs/source/performance.md) | [Examples](./examples/) | [Documentation](./docs/source/)

@@ -173,13 +172,13 @@ Lovelace architectures. Certain limitations may, however, apply.

Various numerical precisions are supported in TensorRT-LLM. The support for
some of those numerical features require specific architectures:

|                              | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
| :--------------------------- | :---- | :---- | :---- | :--- | :--- | :--- |
| Volta (SM70)                 | Y | Y | N | N | Y | Y |
| Turing (SM75)                | Y | Y | N | N | Y | Y |
| Ampere (SM80, SM86)          | Y | Y | Y | N | Y | Y |
| Ada-Lovelace (SM89)          | Y | Y | Y | Y | Y | Y |
| Hopper (SM90)                | Y | Y | Y | Y | Y | Y |

|                     | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
| :------------------ | :--- | :--- | :--- | :--- | :--- | :--- |
| Volta (SM70)        | Y | Y | N | N | Y | Y |
| Turing (SM75)       | Y | Y | N | N | Y | Y |
| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y |
| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
| Hopper (SM90)       | Y | Y | Y | Y | Y | Y |

In this release of TensorRT-LLM, the support for FP8 and quantized data types
(INT8 or INT4) is not implemented for all the models. See the

@@ -217,8 +216,7 @@ The list of supported models is:

* [Bert](examples/bert)
* [Blip2](examples/blip2)
* [BLOOM](examples/bloom)
* [ChatGLM-6B](examples/chatglm6b)
* [ChatGLM2-6B](examples/chatglm2-6b/)
* [ChatGLM](examples/chatglm), including ChatGLM-6B, ChatGLM2-6B, ChatGLM2-6B-32k, ChatGLM3-6B, ChatGLM3-6B-32k
* [Falcon](examples/falcon)
* [GPT](examples/gpt)
* [GPT-J](examples/gptj)

@@ -230,6 +228,7 @@ The list of supported models is:

* [OPT](examples/opt)
* [SantaCoder](examples/gpt)
* [StarCoder](examples/gpt)
* [InternLM](examples/internlm)

## Performance
@@ -18,12 +18,12 @@

#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/tllmLogger.h"

#include <NvInfer.h>
#include <chrono>
#include <cxxopts.hpp>
#include <iostream>
#include <sstream>
#include <string>

@@ -39,14 +39,22 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
GptSession::Config& sessionConfig, bool cudaGraphMode)
{
auto const json = GptJsonConfig::parse(dataPath / "config.json");

std::string modelNameHyphen = modelName;
std::filesystem::path jsonFileName = dataPath / "config.json";
if (tc::strStartsWith(modelName, "chatglm"))
{
std::replace(modelNameHyphen.begin(), modelNameHyphen.end(), '_', '-');
jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
}
auto const json = GptJsonConfig::parse(jsonFileName);
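For ChatGLM-family models the benchmark now normalizes the underscored model name and looks up a per-model config file instead of the fixed config.json. A small, hedged sketch of what that normalization produces; the model name and engine directory are illustrative assumptions:

    #include <algorithm>
    #include <filesystem>
    #include <iostream>
    #include <string>

    int main()
    {
        std::string const modelName = "chatglm2_6b";          // assumed benchmark model name
        std::filesystem::path const dataPath = "/tmp/engines"; // assumed engine directory

        // Same normalization as above: underscores become hyphens for ChatGLM models,
        // and the config file becomes "<hyphenated-name>-config.json".
        std::string modelNameHyphen = modelName;
        std::filesystem::path jsonFileName = dataPath / "config.json";
        if (modelName.rfind("chatglm", 0) == 0) // stand-in for tc::strStartsWith
        {
            std::replace(modelNameHyphen.begin(), modelNameHyphen.end(), '_', '-');
            jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
        }
        std::cout << jsonFileName.string() << '\n'; // prints /tmp/engines/chatglm2-6b-config.json
        return 0;
    }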
auto const modelConfig = json.getModelConfig();
auto const inputPacked = modelConfig.usePackedInput();
SizeType deviceCount{0};
TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
auto const worldConfig
= WorldConfig::mpi(*logger, deviceCount, json.getTensorParallelism(), json.getPipelineParallelism());
auto const enginePath = dataPath / json.engineFilename(worldConfig, modelName);
auto const enginePath = dataPath / json.engineFilename(worldConfig, modelNameHyphen);
auto const dtype = modelConfig.getDataType();
auto const useHalf = (dtype == nvinfer1::DataType::kHALF);

@@ -78,10 +86,15 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

auto constexpr endId = 50256;
auto constexpr padId = 50256;

auto& memoryCounter = MemoryCounters::getInstance();
TLLM_LOG_INFO(memoryCounter.toString());

for (auto const batchSize : batchSizes)
{
try
{
TLLM_LOG_INFO(memoryCounter.toString());

std::vector<SizeType> inputLenghtsHost(batchSize, maxInputLength);
auto inputLenghts
= bufferManager.copyFrom(inputLenghtsHost, ITensor::makeShape({batchSize}), MemoryType::kGPU);

@@ -99,6 +112,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

inputIds = bufferManager.copyFrom(
inputsHost, ITensor::makeShape({batchSize, maxInputLength}), MemoryType::kGPU);
}

TLLM_LOG_INFO(memoryCounter.toString());

GenerationInput generationInput{
endId, padId, std::move(inputIds), std::move(inputLenghts), inputPacked};

@@ -107,6 +123,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};

TLLM_LOG_INFO(memoryCounter.toString());

for (auto r = 0; r < warmUp; ++r)
{
SizeType numSteps = 0;

@@ -118,6 +136,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

}
cudaDeviceSynchronize();

TLLM_LOG_INFO(memoryCounter.toString());

int iterIdx = 0;
float curDuration = 0;
while (iterIdx < numRuns || curDuration / 1000 < duration)

@@ -134,6 +154,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

iterIdx += 1;
curDuration += std::chrono::duration<float, std::milli>(end - start).count();
}

TLLM_LOG_INFO(memoryCounter.toString());

printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000);

if (worldConfig.getRank() == 0)

@@ -159,7 +182,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

// We can ignore the OOM exception and continue the rest of the benchmark
if (worldConfig.getRank() == 0)
{
printf("%s", e.what());
TLLM_LOG_EXCEPTION(e);
printf(
"[BENCHMARK] batch_size %d input_length %d output_length %d latency(ms) N/A tokensPerSec N/A\n",
batchSize, maxInputLength, maxNewTokens);

@@ -167,6 +190,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

continue;
}
}
TLLM_LOG_INFO(memoryCounter.toString());
}
}

@@ -200,8 +224,8 @@ int main(int argc, char* argv[])

options.add_options()("duration", "Minimal duration of iterations to measure in seconds.",
cxxopts::value<int>()->default_value("60"));

options.add_options()(
"num_micro_batches", "Number of micro batches if enabling pipeline parallelism.", cxxopts::value<int>());
options.add_options()("ctx_micro_batch_size", "Batch size for context phase.", cxxopts::value<int>());
options.add_options()("gen_micro_batch_size", "Batch size for generation phase.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());

@@ -281,10 +305,15 @@ int main(int argc, char* argv[])

}

GptSession::Config sessionConfig{0, 0, 0};
// Argument: Number of micro batches
if (result.count("num_micro_batches"))
// Argument: Batch size for context phase
if (result.count("ctx_micro_batch_size"))
{
sessionConfig.numMicroBatches = result["num_micro_batches"].as<int>();
sessionConfig.ctxMicroBatchSize = result["ctx_micro_batch_size"].as<int>();
}
// Argument: Batch size for generation phase
if (result.count("gen_micro_batch_size"))
{
sessionConfig.genMicroBatchSize = result["gen_micro_batch_size"].as<int>();
}
// Argument: Max tokens in paged K-V Cache
if (result.count("max_tokens_in_paged_kvcache"))
@@ -48,6 +48,7 @@ class BuildConfig(BaseModel, extra=Extra.allow):

# default value to be None, not 0 or 1 to prevent misuse
rotary_pct: Optional[float] = None
bias: bool = True
quantization: Optional[str] = None


class ModelConfig(BaseModel):

@@ -121,7 +122,7 @@ _allowed_configs = {

max_input_len=512,
max_output_len=200,
builder_opt=None,
use_smooth_quant=True,
quantization="int8_sq_per_tensor",
)),
"gpt_350m_sq_per_token_channel":
ModelConfig(name="gpt_350m_sq_per_token_channel",

@@ -138,9 +139,7 @@ _allowed_configs = {

max_input_len=512,
max_output_len=200,
builder_opt=None,
use_smooth_quant=True,
per_token=True,
per_channel=True,
quantization="int8_sq_per_token_channel",
)),
"gpt-next_2b":
ModelConfig(name="gpt-next_2b",

@@ -318,7 +317,7 @@ _allowed_configs = {

max_input_len=512,
max_output_len=200,
builder_opt=None,
use_smooth_quant=True)),
quantization="int8_sq_per_tensor")),
"gptj_6b":
ModelConfig(name="gptj_6b",
family="gptj",

@@ -354,7 +353,7 @@ _allowed_configs = {

builder_opt=None,
)),
"chatglm_6b":
ModelConfig(name="chatglm_6b",
ModelConfig(name="chatglm-6b",
family="chatglm",
benchmark_type="gpt",
build_config=BuildConfig(

@@ -371,7 +370,7 @@ _allowed_configs = {

remove_input_padding=False,
)),
"chatglm2_6b":
ModelConfig(name="chatglm2_6b",
ModelConfig(name="chatglm2-6b",
family="chatglm2",
benchmark_type="gpt",
build_config=BuildConfig(

@@ -387,6 +386,23 @@ _allowed_configs = {

builder_opt=None,
remove_input_padding=False,
)),
"chatglm3_6b":
ModelConfig(name="chatglm3-6b",
family="chatglm3",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=28,
num_heads=32,
hidden_size=4096,
vocab_size=65024,
hidden_act='swiglu',
n_positions=2048,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
builder_opt=None,
remove_input_padding=False,
)),
"bloom_560m":
ModelConfig(name="bloom_560m",
family="bloom",
@@ -18,15 +18,11 @@ from multiprocessing import Process, Queue

from time import time

import torch
from allowed_configs import get_allowed_models
from bert_benchmark import BERTBenchmark
from gpt_benchmark import GPTBenchmark
from mem_monitor import mem_monitor

from tensorrt_llm.logger import logger


def parse_arguments():
from allowed_configs import get_allowed_models
parser = argparse.ArgumentParser(
description='Benchmark TensorRT-LLM models.')
parser.add_argument('-m',

@@ -172,18 +168,7 @@ def parse_arguments():

help=
'Quick sanity check with num_layer=1; will be silently ignored if --engine_dir is specified.'
)
parser.add_argument(
'--enable_fp8',
default=False,
action='store_true',
help='Use FP8 Linear layer for LMHead, Attention QKV/Dense, and MLP.')
parser.add_argument(
'--fp8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
)

parser.add_argument('--csv',
default=False,
action="store_true",

@@ -199,11 +184,38 @@ def parse_arguments():

help=
'Use latency-optimized all-reduce for tensor parallelism. Gives better performance with NVLink.'
)
parser.add_argument(
'--strongly_typed',
default=False,
action='store_true',
help=
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
)
parser.add_argument(
'--quantization',
type=str,
default=None,
choices=[
'fp8', 'fp8_gemm', 'fp8_kv_cache', 'int8_sq_per_tensor',
'int8_sq_per_token_channel', 'int8_weight_only', 'int4_weight_only',
'int4_weight_only_awq', 'int4_weight_only_gptq'
],
help="Optimize the model with specified quantization recipe")

return parser.parse_args()


def main(args):
# We import tensorrt_llm here because MPI is initialized when
# tensorrt_llm is imported, but mpi4py does not work well with
# the start method `spawn` of Python multiprocessing,
# so we set the start method first, then initialize MPI.
from allowed_configs import get_allowed_models
from bert_benchmark import BERTBenchmark
from gpt_benchmark import GPTBenchmark

from tensorrt_llm.logger import logger

logger.set_level(args.log_level)

# Batch size

@@ -235,10 +247,10 @@ def main(args):

args.max_output_len,
args.max_batch_size,
force_num_layer_1=args.force_num_layer_1,
enable_fp8=args.enable_fp8,
fp8_kv_cache=args.fp8_kv_cache,
enable_cuda_graph=args.enable_cuda_graph,
enable_custom_all_reduce=args.enable_custom_all_reduce)
enable_custom_all_reduce=args.enable_custom_all_reduce,
strongly_typed=args.strongly_typed,
quantization=args.quantization)
elif args.model in get_allowed_models(benchmark_type="bert"):
benchmarker = BERTBenchmark(args.engine_dir,
args.model,

@@ -273,8 +285,8 @@ def main(args):

# Launch a subprocess to monitor memory usage
q1 = Queue() # q1 is used for sending signal to subprocess
q2 = Queue() # q2 is used for receiving results from subprocess
p = Process(target=mem_monitor, args=(q1, q2))
p.start()
mem_monitor_process = Process(target=mem_monitor, args=(q1, q2))
mem_monitor_process.start()

iter_idx = 0
try:

@@ -301,14 +313,14 @@ def main(args):

except Exception as e:
print("Found exception during benchmarking", e.with_traceback())
p.kill()
mem_monitor_process.kill()
raise e
logger.debug("Sending signal to mem monitor process, start")
q1.put(1)
logger.debug("Sending signal to mem monitor process, done")
peak_gpu_used = q2.get()
logger.debug("Get peak gpu memory usage from mem monitor process, done")
p.join()
mem_monitor_process.join()
logger.debug("Memory monitor process joined")

latency = round(sum(latencies) / iter_idx, 3)
@@ -24,8 +24,7 @@ import tensorrt_llm

from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.models import (fp8_quantize, smooth_quantize,
weight_only_quantize)
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode

@@ -61,6 +60,7 @@ class GPTBenchmark(BaseBenchmark):

self.fuse_bias = True

self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
self.strongly_typed = kwargs.get('strongly_typed', False)
self.enable_custom_all_reduce = enable_custom_all_reduce

if engine_dir is not None:

@@ -73,12 +73,9 @@ class GPTBenchmark(BaseBenchmark):

# Build engine
self.world_size = tensorrt_llm.mpi_world_size()
self.apply_query_key_layer_scaling = False
self.use_smooth_quant = False
# this attribute is not stored in allowed_config
self.enable_fp8 = kwargs.get('enable_fp8', False)
self.fp8_kv_cache = kwargs.get('fp8_kv_cache', False)

self.use_weight_only = False
self.per_group = False
self.weight_only_precision = 'int8'
self.per_token = False
self.per_channel = False

@@ -95,12 +92,17 @@ class GPTBenchmark(BaseBenchmark):

self.use_rmsnorm_plugin = False
self.use_lookup_plugin = non_mha_plg_dtype
self.enable_context_fmha = use_mha_plugin
self.quant_mode = QuantMode(0)

self.remove_input_padding = use_non_mha_plugin

for key, value in get_build_config(model_name).items():
setattr(self, key, value)

if self.quantization is None:
self.quantization = kwargs.get('quantization', None)

self.set_quantization()

# Override the n_position/max_input_len/max_output_len/max_batch_size to value from cmd line if that's specified.
if n_positions is not None:
assert isinstance(

@@ -126,20 +128,6 @@ class GPTBenchmark(BaseBenchmark):

self.num_kv_heads = self.num_heads
if kwargs.get('force_num_layer_1', False):
self.num_layers = 1

if self.use_smooth_quant:
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)
elif self.use_weight_only:
self.quant_mode = QuantMode.use_weight_only(
self.weight_only_precision == 'int4')

if self.enable_fp8:
self.quant_mode = self.quant_mode.set_fp8_qdq()

if self.fp8_kv_cache:
self.quant_mode = self.quant_mode.set_fp8_kv_cache()

engine_buffer = self.build()

assert engine_buffer is not None

@@ -155,16 +143,25 @@ class GPTBenchmark(BaseBenchmark):

quant_mode=self.quant_mode,
use_custom_all_reduce=self.enable_custom_all_reduce,
)
if model_name == 'chatglm_6b':
if model_name == 'chatglm-6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=130005,
pad_id=3,
num_beams=num_beams,
top_k=top_k,
top_p=top_p)
self.decoder = tensorrt_llm.runtime.ChatGLM6BHeadModelGenerationSession(
self.decoder = tensorrt_llm.runtime.ChatGLMGenerationSession(
model_config, engine_buffer, self.runtime_mapping)
elif model_name == 'chatglm2_6b':
elif model_name == 'chatglm2-6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=2,
pad_id=0,
num_beams=num_beams,
top_k=top_k,
top_p=top_p)
self.decoder = tensorrt_llm.runtime.GenerationSession(
model_config, engine_buffer, self.runtime_mapping)
elif model_name == 'chatglm3-6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=2,
pad_id=0,

@@ -212,6 +209,75 @@ class GPTBenchmark(BaseBenchmark):

self.decoder.setup(batch_size, inlen, outlen, beam_width=self.num_beams)
return (input_ids, input_lengths)

def set_quantization(self):
self.quant_mode = QuantMode(0)

if self.quantization == "fp8":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_qdq()
self.quant_mode = self.quant_mode.set_fp8_kv_cache()

elif self.quantization == "fp8_gemm":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_qdq()

elif self.quantization == "fp8_kv_cache":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_kv_cache()

elif self.quantization == "int8_sq_per_tensor":
self.use_smooth_quant = True
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)

elif self.quantization == "int8_sq_per_token_channel":
self.use_smooth_quant = True
self.per_token = True
self.per_channel = True
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)

elif self.quantization == "int8_weight_only":
self.use_smooth_quant = False
self.use_weight_only = True
self.weight_only_precision = 'int8'
self.quant_mode = QuantMode.use_weight_only(False)

elif self.quantization == "int4_weight_only":
self.use_weight_only = True
self.weight_only_precision = 'int4'
self.quant_mode = QuantMode.use_weight_only(True)

elif self.quantization == "int4_weight_only_awq":
self.use_weight_only = True
self.per_group = True
self.weight_only_precision = 'int4_awq'
self.quant_mode = QuantMode.from_description(
quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)

elif self.quantization == "int4_weight_only_gptq":
self.use_weight_only = True
self.per_group = True
self.weight_only_precision = 'int4_gptq'
self.quant_mode = QuantMode.from_description(
quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)

elif self.quantization == None:
pass

else:
raise Exception(f'{0} is invalid config: {self.quantization}')

def build(self):
builder = Builder()
builder_config = builder.create_builder_config(

@@ -232,10 +298,10 @@ class GPTBenchmark(BaseBenchmark):

max_input_len=self.max_input_len,
max_output_len=self.max_output_len,
int8=self.quant_mode.has_act_and_weight_quant(),
fp8=self.quant_mode.has_fp8_qdq(),
quant_mode=self.quant_mode,
use_refit=self.refit,
opt_level=self.builder_opt)
opt_level=self.builder_opt,
strongly_typed=self.strongly_typed)
engine_name = get_engine_name(self.model_name, self.dtype,
self.world_size, self.runtime_rank)

@@ -322,7 +388,7 @@ class GPTBenchmark(BaseBenchmark):

apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling)
elif family == "chatglm":
tensorrt_llm_model = tensorrt_llm.models.ChatGLM6BHeadModel(
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,

@@ -335,9 +401,10 @@ class GPTBenchmark(BaseBenchmark):

tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode)
quant_mode=self.quant_mode,
model_version="1")
elif family == "chatglm2":
tensorrt_llm_model = tensorrt_llm.models.ChatGLM2_6BHeadModel(
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,

@@ -350,7 +417,24 @@ class GPTBenchmark(BaseBenchmark):

tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode)
quant_mode=self.quant_mode,
model_version="2")
elif family == "chatglm3":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode,
model_version="3")
elif family == "bloom":
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
num_layers=self.num_layers,

@@ -362,6 +446,7 @@ class GPTBenchmark(BaseBenchmark):

mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
quant_mode=self.quant_mode,
use_parallel_embedding=(self.model_name == 'bloom_176b'))
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(

@@ -381,27 +466,34 @@ class GPTBenchmark(BaseBenchmark):

else:
raise Exception(f'Unexpected model: {self.model_name}')

if self.use_smooth_quant:
tensorrt_llm_model = smooth_quantize(tensorrt_llm_model,
self.quant_mode)
elif self.use_weight_only and self.weight_only_precision == 'int8':
tensorrt_llm_model = weight_only_quantize(
tensorrt_llm_model, QuantMode.use_weight_only())
elif self.use_weight_only and self.weight_only_precision == 'int4':
tensorrt_llm_model = weight_only_quantize(
tensorrt_llm_model,
QuantMode.use_weight_only(use_int4_weights=True))
elif self.enable_fp8 or self.fp8_kv_cache:
tensorrt_llm_model = fp8_quantize(tensorrt_llm_model,
self.quant_mode)
quant_kwargs = {}
if family == "llama" and self.use_weight_only:
if self.weight_only_precision == 'int4_awq':
quant_kwargs = {
"group_size": 128,
"zero": False,
"pre_quant_scale": True,
"exclude_modules": [],
}
elif self.weight_only_precision == 'int4_gptq':
quant_kwargs = {
"group_size": 128,
"zero": True,
"pre_quant_scale": False,
}
tensorrt_llm_model = quantize_model(tensorrt_llm_model, self.quant_mode,
**quant_kwargs)

# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name

not_fp8_quantization = self.quantization is None or "fp8" not in self.quantization

if self.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=self.use_gpt_attention_plugin)
if self.use_gemm_plugin:
if self.use_gemm_plugin and not_fp8_quantization:
network.plugin_config.set_gemm_plugin(dtype=self.use_gemm_plugin)
if self.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
@@ -27,10 +27,14 @@ project(tensorrt_llm LANGUAGES CXX)

# Build options
option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
OFF)
option(BUILD_TESTS "Build Google tests" ON)
option(BUILD_BENCHMARKS "Build benchmarks" ON)
option(NVTX_DISABLE "Disable all NVTX features" ON)
option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
option(FAST_MATH "Compiling in fast math mode" OFF)

if(NVTX_DISABLE)
add_compile_definitions("NVTX_DISABLE")

@@ -73,6 +77,11 @@ else()

message(STATUS "Not building benchmarks")
endif()

if(FAST_BUILD)
add_compile_definitions("FAST_BUILD")
message(WARNING "Skip some kernels to accelerate compilation")
endif()

# Determine CUDA version before enabling the language extension
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)

@@ -229,6 +238,10 @@ endif()

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
if(FAST_MATH)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --use_fast_math")
message("CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
endif()

set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR})
message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}")

@@ -333,3 +346,11 @@ if(BUILD_BENCHMARKS)

add_subdirectory(${TRT_LLM_ROOT_DIR}/benchmarks/cpp
${CMAKE_BINARY_DIR}/benchmarks)
endif()

# Measure the compile time
option(MEASURE_BUILD_TIME "Measure the build time of each module" OFF)
if(MEASURE_BUILD_TIME)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time")
set_property(GLOBAL PROPERTY RULE_LAUNCH_CUSTOM "${CMAKE_COMMAND} -E time")
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time")
endif()
@@ -46,6 +46,7 @@ class GptManager

public:
using SizeType = tensorrt_llm::runtime::SizeType;
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
using TensorPtr = runtime::ITensor::SharedPtr;

GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,

@@ -108,6 +109,9 @@ private:

inline static const std::string kBeamWidthTensorName_ = "beam_width";
inline static const std::string kEndIdTensorName_ = "end_id";
inline static const std::string kPadIdTensorName_ = "pad_id";
inline static const std::string kBadWordsListTensorName_ = "bad_words_list";
inline static const std::string kStopWordsListTensorName_ = "stop_words_list";
inline static const std::string kEmbeddingBiasTensorName_ = "embedding_bias";
inline static const std::string kTemperatureTensorName_ = "temperature";
inline static const std::string kRuntimeTopKTensorName_ = "runtime_top_k";
inline static const std::string kRuntimeTopPTensorName_ = "runtime_top_p";

@@ -116,6 +120,8 @@ private:

inline static const std::string kMinLengthTensorName_ = "min_length";
inline static const std::string kPresencePenaltyTensorName_ = "presence_penalty";
inline static const std::string kRandomSeedTensorName_ = "random_seed";
inline static const std::string kPromptEmbeddingTableName_ = "prompt_embedding_table";
inline static const std::string kPromptVocabSizeName_ = "prompt_vocab_size";
inline static const std::string kOutputIdsTensorName_ = "output_ids";
inline static const std::string kSequenceLengthTensorName_ = "sequence_length";

@@ -33,6 +33,16 @@

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

using SizeType = tensorrt_llm::runtime::SizeType;

struct KvCacheStats
{
SizeType maxNumBlocks;
SizeType freeNumBlocks;
SizeType usedNumBlocks;
SizeType toksPerBlock;
};

// Basic building block of a paged KV cache - a single
// cache block. This class just holds metadata, no pointers
// since it is reused across all layers.

@@ -231,6 +241,17 @@ public:

return mBlockManager.getNumFreeBlocks();
}

[[nodiscard]] KvCacheStats getKvCacheStats() const
{
KvCacheStats kvCacheStats;
kvCacheStats.maxNumBlocks = getMaxNumBlocks();
kvCacheStats.freeNumBlocks = getNumFreeBlocks();
kvCacheStats.usedNumBlocks = getUsedNumBlocks();
kvCacheStats.toksPerBlock = getTokensPerBlock();

return kvCacheStats;
}
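getKvCacheStats() bundles the paged-KV-cache counters into a single struct. A minimal, hedged sketch of how a caller might turn it into a utilization figure (kvCacheManager is an assumed, already-constructed KV cache manager instance):

    auto const stats = kvCacheManager.getKvCacheStats();
    auto const usedBlocks = stats.maxNumBlocks - stats.freeNumBlocks;
    auto const utilization = stats.maxNumBlocks > 0
        ? static_cast<float>(usedBlocks) / static_cast<float>(stats.maxNumBlocks)
        : 0.0F;
    // stats.usedNumBlocks and stats.toksPerBlock can likewise be used to estimate
    // how many more tokens fit before the cache is exhausted.
    std::cout << "KV cache utilization: " << 100.0F * utilization << "%\n";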
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
[[nodiscard]] SizeType getBlockSize() const
{

@@ -16,6 +16,8 @@

#pragma once

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"

#include <assert.h>

@@ -41,10 +43,14 @@ public:

using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using BeamTokens = std::vector<std::vector<TokenIdType>>;
using TensorPtr = runtime::ITensor::SharedPtr;

LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt)
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt)
: mRequestId(requestId)
, mPromptLen(input_tokens->size())
, mMaxNewTokens(maxNewTokens)

@@ -54,10 +60,25 @@ public:

, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
, mPromptEmbeddingTable(promptEmbeddingTable)
, mPromptVocabSize(promptVocabSize)
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to other beam
mTokens = std::make_shared<BeamTokens>(mSamplingConfig.beamWidth, *input_tokens);

if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
{
std::string errStr
= "Prompt embedding table and prompt vocab size tensors must both be provided for requests with prompt "
"tuning enabled.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
}
}
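The constructor now takes optional embedding-bias, bad-words, stop-words and prompt-tuning arguments, and rejects requests that pass only one half of the prompt-tuning pair. A hedged sketch of constructing such a request; the token ids, request id and the promptTable / bufferManager variables are illustrative assumptions:

    auto tokens = std::make_shared<std::vector<LlmRequest::TokenIdType>>(
        std::vector<LlmRequest::TokenIdType>{1, 2, 3, 4});
    runtime::SamplingConfig samplingConfig{/*beamWidth=*/1};

    LlmRequest request(/*requestId=*/42, /*maxNewTokens=*/64, tokens, samplingConfig,
        /*isStreaming=*/false, /*endId=*/50256, /*padId=*/50256,
        /*embeddingBias=*/std::nullopt, /*badWordsList=*/std::nullopt, /*stopWordsList=*/std::nullopt,
        /*promptEmbeddingTable=*/promptTable, // must be given together with promptVocabSize,
        /*promptVocabSize=*/8);               // otherwise the constructor above throws

    // Stage the table on the device before scheduling; a no-op if it is already on the GPU.
    request.movePromptEmbeddingTableToGpu(bufferManager);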
/// @brief Get total number of tokens for this req (prompt + generated)

@@ -104,6 +125,14 @@ public:

return getMaxBeamNumTokens() - mPromptLen;
}

/// @brief Add new generated tokens to the vector of tokens
/// @param token The token to add
/// @param beam The beam to which to add the new token
void addNewToken(TokenIdType token, SizeType beam)
{
mTokens->at(beam).push_back(token);
}

/// @brief Add new generated tokens to the vector of tokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth

@@ -174,6 +203,46 @@ public:

mMaxSentTokenPos = pos;
}

std::optional<TensorPtr> getPromptEmbeddingTable() const
{
return mPromptEmbeddingTable;
}

void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
{
if (!mPromptEmbeddingTable.has_value()
|| mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
{
return;
}
else
{
TensorPtr gpuPromptEmbeddingTable
= manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
mPromptEmbeddingTable = gpuPromptEmbeddingTable;
}
}

std::optional<SizeType> getPromptVocabSize() const
{
return mPromptVocabSize;
}

std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
}

std::optional<TensorPtr> getBadWordsList() const
{
return mBadWordsList;
}

std::optional<TensorPtr> getStopWordsList() const
{
return mStopWordsList;
}

RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;

@@ -188,6 +257,13 @@ public:

private:
std::shared_ptr<BeamTokens> mTokens;
SizeType mMaxSentTokenPos;

std::optional<TensorPtr> mEmbeddingBias;
std::optional<TensorPtr> mBadWordsList;
std::optional<TensorPtr> mStopWordsList;

std::optional<TensorPtr> mPromptEmbeddingTable;
std::optional<SizeType> mPromptVocabSize;
};

} // namespace tensorrt_llm::batch_manager

@@ -19,6 +19,7 @@

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/promptTuningParams.h"

#include <optional>
#include <utility>

@@ -26,18 +27,20 @@

namespace tensorrt_llm::runtime
{

class GenerationInput
template <typename TTensor, typename PromptTuningParams>
class GenericGenerationInput
{
public:
using TensorPtr = ITensor::SharedPtr;
using TensorPtr = TTensor;

explicit GenerationInput(
explicit GenericGenerationInput(
SizeType const endId, SizeType const padId, TensorPtr ids, TensorPtr lengths, bool packed = false)
: endId{endId}
, padId{padId}
, ids{std::move(ids)}
, lengths{std::move(lengths)}
, packed{packed}
, maxNewTokens(std::nullopt)
{
TLLM_CHECK_WITH_INFO(static_cast<bool>(this->ids), "Invalid ids tensor");
TLLM_CHECK_WITH_INFO(static_cast<bool>(this->lengths), "Invalid lengths tensor");

@@ -55,6 +58,22 @@ public:

TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
TensorPtr stopWordsList; // [batchSize, 2, stopWordsLength], on gpu
std::optional<SizeType> maxNewTokens; // max number of tokens to generate

// Ptuning parameters
PromptTuningParams promptTuningParams; // See promptTuningParams.h for expected shapes
};

class GenerationInput : public GenericGenerationInput<ITensor::SharedPtr, PromptTuningParams>
{
public:
using Base = GenericGenerationInput<ITensor::SharedPtr, PromptTuningParams>;
using TensorPtr = Base::TensorPtr;

explicit GenerationInput(
SizeType const endId, SizeType const padId, TensorPtr ids, TensorPtr lengths, bool packed = false)
: GenericGenerationInput(endId, padId, std::move(ids), std::move(lengths), packed)
{
}
};
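GenerationInput is now a thin subclass of the templated GenericGenerationInput, so the same field set (endId, padId, ids, lengths, packed, maxNewTokens, promptTuningParams) can be reused with a different tensor handle, for instance by the Python bindings enabled via the new BUILD_PYBIND option. A hedged sketch of that reuse; PyTensor and the derived class name are purely hypothetical placeholders:

    struct PyTensor; // hypothetical binding-side tensor type
    using PyTensorPtr = std::shared_ptr<PyTensor>;
    using PyPromptTuningParams = tensorrt_llm::runtime::GenericPromptTuningParams<PyTensorPtr>;

    class PyGenerationInput
        : public tensorrt_llm::runtime::GenericGenerationInput<PyTensorPtr, PyPromptTuningParams>
    {
    public:
        using Base = GenericGenerationInput<PyTensorPtr, PyPromptTuningParams>;
        using Base::Base; // inherit the validating constructor instead of re-declaring the fields
    };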
} // namespace tensorrt_llm::runtime

@@ -26,14 +26,14 @@

namespace tensorrt_llm::runtime
{

class GenerationOutput
template <typename TTensor>
class GenericGenerationOutput
{
public:
using TensorPtr = ITensor::SharedPtr;

using TensorPtr = TTensor;
using Callback = std::function<void(TensorPtr const& ids, SizeType step, bool finished)>;

explicit GenerationOutput(TensorPtr ids, TensorPtr lengths)
explicit GenericGenerationOutput(TensorPtr ids, TensorPtr lengths)
: ids{std::move(ids)}
, lengths{std::move(lengths)}
{

@@ -53,4 +53,16 @@ public:

Callback onTokenGenerated;
};

class GenerationOutput : public GenericGenerationOutput<ITensor::SharedPtr>
{
public:
using Base = GenericGenerationOutput<ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;

explicit GenerationOutput(TensorPtr ids, TensorPtr lengths)
: GenericGenerationOutput(std::move(ids), std::move(lengths))
{
}
};
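The per-step Callback hook is unchanged by the refactor. A small, hedged sketch of wiring it up on a GenerationOutput (the bufferManager variable is assumed to exist, mirroring the benchmark code earlier in this commit; the lambda body is illustrative):

    GenerationOutput generationOutput{
        bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
        bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};

    generationOutput.onTokenGenerated
        = [](GenerationOutput::TensorPtr const& ids, SizeType step, bool finished)
    {
        // ids holds the output ids produced so far; copy them out or stream them here.
        if (finished)
        {
            std::cout << "generation finished after step " << step << "\n";
        }
    };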
} // namespace tensorrt_llm::runtime

@@ -29,7 +29,7 @@ public:

enum class ModelVariant : std::int32_t
{
kGpt = 0,
kGlm = 1, // https://github.com/THUDM/GLM
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
};

constexpr explicit GptModelConfig(

@@ -52,6 +52,7 @@ public:

, mComputeContextLogits(false)
, mModelVariant(ModelVariant::kGpt)
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
{
}

@@ -196,6 +197,21 @@ public:

mMaxNumTokens = maxNumTokens;
}

[[nodiscard]] bool constexpr usePromptTuning() const noexcept
{
return mMaxPromptEmbeddingTableSize > 0;
}

[[nodiscard]] SizeType constexpr getMaxPromptEmbeddingTableSize() const noexcept
{
return mMaxPromptEmbeddingTableSize;
}

void constexpr setMaxPromptEmbeddingTableSize(SizeType maxPromptEmbeddingTableSize) noexcept
{
mMaxPromptEmbeddingTableSize = maxPromptEmbeddingTableSize;
}

[[nodiscard]] bool constexpr computeContextLogits() const noexcept
{
return mComputeContextLogits;

@@ -246,6 +262,8 @@ private:

bool mComputeContextLogits;
ModelVariant mModelVariant;
bool mUseCustomAllReduce;

SizeType mMaxPromptEmbeddingTableSize;
};

} // namespace tensorrt_llm::runtime

@@ -53,10 +53,11 @@ namespace utils

std::vector<uint8_t> loadEngine(std::string const& enginePath);
}

class TllmRuntime;
class IpcMemory;
class IStatefulGptDecoder;
class NcclCommunicator;
class RuntimeBuffers;
class TllmRuntime;

class GptSession
{

@@ -85,7 +86,8 @@ public:

bool decoderPerRequest{false};
bool cudaGraphMode{false};
KvCacheConfig kvCacheConfig{};
std::optional<SizeType> numMicroBatches = std::nullopt;
std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
std::optional<SizeType> genMicroBatchSize = std::nullopt;
};
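GptSession::Config replaces the single numMicroBatches knob with separate context and generation micro-batch sizes, mirroring the new --ctx_micro_batch_size and --gen_micro_batch_size benchmark flags. A hedged sketch of filling the config; the three positional fields are assumed to be {maxBatchSize, maxBeamWidth, maxSequenceLength}, matching the benchmark's "GptSession::Config sessionConfig{0, 0, 0}" usage, and the numbers are illustrative:

    GptSession::Config sessionConfig{/*maxBatchSize=*/64, /*maxBeamWidth=*/1, /*maxSequenceLength=*/2048};
    sessionConfig.ctxMicroBatchSize = 8;  // split the context (prefill) phase into micro batches of 8
    sessionConfig.genMicroBatchSize = 32; // run the generation phase with larger micro batches
    sessionConfig.cudaGraphMode = true;   // optional, unchanged by this commit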
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,

@@ -136,7 +138,7 @@ private:

void setup(Config const& sessionConfig);

void createContexts(SizeType numMicroBatches, bool useCudaGraphs);
void createContexts(SizeType numBatchesCtx, SizeType numBatchesGen, bool useCudaGraphs);
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);

@@ -144,6 +146,12 @@ private:

SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);

void executeContextStep(std::vector<GenerationInput> const& microBatches,
std::vector<SizeType> const& microBatchOffsets, KvCacheManager const* kvCacheManager);
SizeType executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatches,
std::vector<SizeType> const& microBatchOffsets, KvCacheManager* kvCacheManager,
std::vector<bool>& microBatchesFinished);

//! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
void decoderStepAsync(SizeType decoderStep, SizeType microBatchId);

@@ -156,11 +164,11 @@ private:

void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, SizeType firstBatchIdx);

ITensor::SharedPtr initNewTokens(
GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId);
//! @brief Populate outputIds and return reference to newTokens tensor
ITensor::SharedPtr initDecoder(ITensor& outputIds, GenerationInput const& inputs,
SamplingConfig const& samplingConfig, SizeType microBatchId) const;

std::function<void(SizeType microBatchId, SizeType step, bool finished)> createOnTokenGeneratedCallback(
GenerationOutput& outputs, SizeType numMicroBatches);
std::function<void(SizeType step, bool finished)> createOnTokenGeneratedCallback(GenerationOutput& outputs);

class CudaGraphExecutor
{

@@ -196,6 +204,45 @@ private:

cudaGraphExec_t mInstance;
};

class MicroBatchConfig
{
public:
MicroBatchConfig()
: numCtxBatches{1}
, numGenBatches{1}
, ctxBatchSize{0}
, genBatchSize{0}
{
}

explicit MicroBatchConfig(SizeType maxBatchSize, SizeType pipelineParallelism,
std::optional<SizeType> genMicroBatchSize, std::optional<SizeType> ctxMicroBatchSize);

constexpr SizeType numCtxPerGen() const
{
return numCtxBatches / numGenBatches;
}

//! @details First 2 * numGenBatches contexts are for generation phase, next numCtxBatches are for context
//! phase. Use numCtxPerGen() contexts for the context batches of each generation batch.
constexpr SizeType getCtxContextId(SizeType generationBatchId, SizeType contextBatchId) const
{
return 2 * numGenBatches + generationBatchId * numCtxPerGen() + contextBatchId;
}

//! @details First 2 * numGenBatches contexts are for generation phase, flip-flop between 2 of them for each
//! generation batch.
constexpr SizeType getGenContextId(SizeType flipFlopId, SizeType generationBatchId) const
{
return flipFlopId * numGenBatches + generationBatchId;
}

SizeType numCtxBatches;
SizeType numGenBatches;
SizeType ctxBatchSize;
SizeType genBatchSize;
};
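The doc comments above define how TensorRT execution contexts are laid out: the first 2 * numGenBatches contexts serve the generation phase (two per generation batch, used flip-flop), and the following numCtxBatches serve the context phase. A small worked sketch under the assumption numGenBatches = 2 and numCtxBatches = 4:

    MicroBatchConfig cfg;
    cfg.numGenBatches = 2; // contexts 0..3 are generation contexts (2 per generation batch)
    cfg.numCtxBatches = 4; // contexts 4..7 are context-phase contexts, numCtxPerGen() == 2

    // Generation batch 1, flip-flop buffers 0 and 1:
    auto gen0 = cfg.getGenContextId(/*flipFlopId=*/0, /*generationBatchId=*/1); // 0 * 2 + 1 == 1
    auto gen1 = cfg.getGenContextId(/*flipFlopId=*/1, /*generationBatchId=*/1); // 1 * 2 + 1 == 3

    // The two context batches that feed generation batch 1:
    auto ctx0 = cfg.getCtxContextId(/*generationBatchId=*/1, /*contextBatchId=*/0); // 2*2 + 1*2 + 0 == 6
    auto ctx1 = cfg.getCtxContextId(/*generationBatchId=*/1, /*contextBatchId=*/1); // 2*2 + 1*2 + 1 == 7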
friend class batch_manager::TrtGptModelV1;

private:

@@ -206,13 +253,17 @@ private:

std::shared_ptr<CudaStream> mCommStream;
CudaEvent mCommEvent{};

// tensor parallelism with custom allreduce plugin
ITensor::SharedPtr mCommPtrs;
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;

SizeType mDecoderMaxSequenceLength{};

LoggerPtr mLogger;
std::shared_ptr<TllmRuntime> mRuntime;
std::shared_ptr<KvCacheManager> mKvCacheManager;

SizeType mNumMicroBatches;
MicroBatchConfig mMicroBatchConfig;
// for each micro batch
std::vector<std::shared_ptr<IStatefulGptDecoder>> mDecoders;
std::vector<std::shared_ptr<RuntimeBuffers>> mBuffers;

@@ -35,9 +35,10 @@ namespace decoder_batch

class Request
{
public:
using TensorPtr = std::shared_ptr<ITensor const>;
using ConstTensorPtr = std::shared_ptr<ITensor const>;
using TensorPtr = std::shared_ptr<ITensor>;

explicit Request(TensorPtr ids, std::optional<SizeType> maxNewTokens = std::nullopt,
explicit Request(ConstTensorPtr ids, std::optional<SizeType> maxNewTokens = std::nullopt,
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt)
: ids{std::move(ids)}
, maxNewTokens{maxNewTokens}

@@ -46,7 +47,7 @@ public:

}

// mandatory parameters
TensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu

// optional parameters
std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request

@@ -114,6 +114,25 @@ public:

return newDims;
}

//!
//! \brief Add a *unit* dimension to `shape` at the specified position.
//!
//! \param shape The shape to unsqueeze.
//! \param dim The dimension where unit dimension should be added.
//! \return A new shape with the added unit dimension.
//!
static Shape unsqueeze(Shape const& shape, SizeType dim)
{
TLLM_CHECK_WITH_INFO(dim <= shape.nbDims && dim >= 0,
common::fmtstr("Invalid dim %d, tensor has %d dimensions", dim, shape.nbDims));

Shape newDims{shape.nbDims + 1};
std::copy(shape.d, shape.d + dim, newDims.d);
newDims.d[dim] = 1;
std::copy(shape.d + dim, shape.d + shape.nbDims, newDims.d + dim + 1);
return newDims;
}

//!
//! \brief Removes the given *unit* dimensions from this tensor.
//!

@@ -122,6 +141,14 @@ public:

reshape(squeeze(getShape(), dim));
}

//!
//! \brief Adds a *unit* dimension at the specified position
//!
void unsqueeze(SizeType dim)
{
reshape(unsqueeze(getShape(), dim));
}
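unsqueeze complements the existing squeeze, both as a static shape helper and as an in-place reshape. A minimal, hedged usage sketch (tensor is an assumed ITensor::SharedPtr of shape [batchSize, seqLen]):

    // Shape-only form: [batchSize, seqLen] -> [1, batchSize, seqLen]
    auto const expanded = ITensor::unsqueeze(tensor->getShape(), /*dim=*/0);

    // In-place form on the tensor itself, and its inverse:
    tensor->unsqueeze(0); // now [1, batchSize, seqLen]
    tensor->squeeze(0);   // back to [batchSize, seqLen]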
//!
//! \brief Creates a sliced view on the underlying `tensor`. The view will have the same data type as `tensor`.
//!

@@ -127,6 +127,8 @@ public:

static std::string bytesToString(DiffType bytes, int precision = 2);

std::string toString() const;

private:
SizeType mGpu{}, mCpu{}, mPinned{};
DiffType mGpuDiff{}, mCpuDiff{}, mPinnedDiff{};

cpp/include/tensorrt_llm/runtime/promptTuningParams.h (new file, 77 lines)

@@ -0,0 +1,77 @@

/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/tllmBuffers.h"

#include <optional>
#include <utility>

namespace tensorrt_llm::runtime
{

template <typename TTensor>
class GenericPromptTuningParams
{
public:
using TensorPtr = TTensor;
using SizeType = tensorrt_llm::runtime::SizeType;

explicit GenericPromptTuningParams(
TensorPtr embeddingTable = TensorPtr(), TensorPtr tasks = TensorPtr(), TensorPtr vocabSize = TensorPtr())
: embeddingTable{std::move(embeddingTable)}
, tasks{std::move(tasks)}
, vocabSize{std::move(vocabSize)} {};

// The prompt embedding table
TensorPtr embeddingTable; // [numTasks * taskVocabSize, hidden_dim], on gpu
// In GenerationInput, tasks expected shape is [batchSize]
// For context requests with non-packed inputs, expected shape is [batchSize, 1]
// For generation requests with non-packed inputs, expected shape is [batchSize*beamWidth] for generation requests.
// For packed inputs, expected shape is [1, packedLength] (note that ifb currently doesn't support non-packed
// inputs)
TensorPtr tasks;
TensorPtr vocabSize; // [1], on gpu

std::vector<bool>
promptTuningEnabled; // [batchSize] vector of bool that indicates which requests in a batch have ptuning enabled
};

class PromptTuningParams : public GenericPromptTuningParams<ITensor::SharedPtr>
{
public:
using TensorPtr = ITensor::SharedPtr;
using SizeType = GenericPromptTuningParams::SizeType;

explicit PromptTuningParams(
TensorPtr embeddingTable = nullptr, TensorPtr tasks = nullptr, TensorPtr vocabSize = nullptr)
: GenericPromptTuningParams(std::move(embeddingTable), std::move(tasks), std::move(vocabSize))
{
}

// Fill the tasks tensor for the batch using the provided tasksHost
// Function assumes that the first numContextRequests requests in the batch are context requests
void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests,
const std::vector<SizeType>& reqBeamWidths, const std::vector<SizeType>& reqPromptLengths,
BufferManager& manager, bool packedInput);
};
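PromptTuningParams gathers the prompt-tuning ("p-tuning") inputs referenced from GenerationInput. A hedged sketch of filling it for a non-packed context batch, following the shape comments above; the promptTable, vocabSizeTensor and bufferManager variables and all sizes are illustrative assumptions:

    PromptTuningParams ptuning;
    ptuning.embeddingTable = promptTable;   // [numTasks * taskVocabSize, hiddenDim], on gpu
    ptuning.vocabSize = vocabSizeTensor;    // [1], holds taskVocabSize
    ptuning.promptTuningEnabled = {true, true, true, true};

    // One task id per request; non-packed context requests expect shape [batchSize, 1].
    std::vector<SizeType> tasksHost{0, 0, 1, 1};
    auto tasksHostTensor = bufferManager.copyFrom(tasksHost, ITensor::makeShape({4, 1}), MemoryType::kCPU);

    ptuning.fillTasksTensor(tasksHostTensor, /*batchSize=*/4, /*numContextRequests=*/4,
        /*reqBeamWidths=*/{1, 1, 1, 1}, /*reqPromptLengths=*/{16, 16, 32, 32},
        bufferManager, /*packedInput=*/false);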
} // namespace tensorrt_llm::runtime

@@ -70,29 +70,40 @@ if(NOT WIN32) # Linux

endif()
else() # Windows
# AMD64, IA64, ARM64, EM64T, X86
set(BATCH_MANAGER_TARGET_ARCH "${CMAKE_SYSTEM_PROCESSOR}-WINDOWS")
string(TOLOWER ${BATCH_MANAGER_TARGET_ARCH} ${BATCH_MANAGER_TARGET_ARCH})
if(CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
set(BATCH_MANAGER_TARGET_ARCH "x86_64-windows-msvc")
else()
message(
FATAL_ERROR
"The system processor type is unsupported: ${CMAKE_SYSTEM_PROCESSOR}")
endif()
endif()

if(BUILD_BATCH_MANAGER)
add_subdirectory(batch_manager)
else()
add_library(${BATCH_MANAGER_TARGET} STATIC IMPORTED)
execute_process(
COMMAND ${Python3_EXECUTABLE} "-c"
"import torch; print(torch.compiled_with_cxx11_abi(),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE USE_CXX11_ABI)
if(NOT WIN32) # Linux
execute_process(
COMMAND ${Python3_EXECUTABLE} "-c"
"import torch; print(torch.compiled_with_cxx11_abi(),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE USE_CXX11_ABI)

message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}")
message(STATUS "USE_CXX11_ABI: ${USE_CXX11_ABI}")

if(USE_CXX11_ABI)
if(USE_CXX11_ABI)
set(BATCH_MANAGER_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.a"
)
else()
set(BATCH_MANAGER_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.pre_cxx11.a"
)
endif()
else() # Windows
set(BATCH_MANAGER_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.a"
)
else()
set(BATCH_MANAGER_LIB_LOC
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/libtensorrt_llm_batch_manager_static.pre_cxx11.a"
"${CMAKE_CURRENT_SOURCE_DIR}/batch_manager/${BATCH_MANAGER_TARGET_ARCH}/tensorrt_llm_batch_manager_static.lib"
)
endif()
set_property(TARGET ${BATCH_MANAGER_TARGET} PROPERTY IMPORTED_LOCATION

@@ -132,7 +143,7 @@ set_target_properties(

CXX_EXTENSIONS "NO")

if(NOT MSVC) # Unix-like compilers
set(ALLOW_UNDEFINED_FLAG "-Wl, --no-undefined")
set(UNDEFINED_FLAG "-Wl,--no-undefined")
else() # MSVC
set(UNDEFINED_FLAG "")
endif()

@@ -158,4 +169,8 @@ if(BUILD_PYT)

add_subdirectory(thop)
endif()

if(BUILD_PYBIND)
add_subdirectory(pybind)
endif()

add_subdirectory(plugins)

@@ -1,3 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:422df71fccde81a55049fb61996d0b88bbaf1f18866b63c8e73c36b772c2df46
size 1508332
oid sha256:f591dd181613b14f7ded3ba3e167d14073564254bc46db8c4bd9636d6d896b16
size 1611436

@@ -1,3 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:0013625bc6b18255f44d6ab38e8ea0bceda6452bddf9df3cf832ad106fc2058d
size 1516676
oid sha256:21d17a9fa736d033ad77270a0fbcdd09c27dfab3f871d92a5ffa0cb744fa48fd
size 1623126

@@ -1,3 +1,3 @@

bda56cf4ad2242be25115ddecd23e7df libtensorrt_llm_batch_manager_static.a
12d7c8e5b4a018dfd9043fa7db979b5a libtensorrt_llm_batch_manager_static.pre_cxx11.a
7e492cc1057b1091f62d69df81547cb071729e5d commit
e1dc326c0c45864b9e7963b4d92d322f libtensorrt_llm_batch_manager_static.a
d2e9d76efe6b4173270aa6b494dfe59c libtensorrt_llm_batch_manager_static.pre_cxx11.a
07363ea7a6fdd6eeedc1670dedeeaedff7f9a848 commit

@@ -1,3 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:c5a207480594cb228b7264f28af85b0a820046f64379f11fd7389c701ca5497d
size 1421186
oid sha256:3fe444bf079ce35262b932302806b372ccb677182969e3bba45698343e5e350f
size 1523444

@@ -1,3 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:80e06e15b9e29ba80c036ba6604a2ce286acb294eddb50015bad53cfdeba4534
size 1423958
oid sha256:99641389fdf26f6324b7465df0b61b74946787a6a147d145de23b444261e6e5f
size 1524188

@@ -0,0 +1,2 @@

b10b0e00d0132b04969d779af45d73d0 libtensorrt_llm_batch_manager_static.a
3ad06255afdaa8450c133d1d1bc486c4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
36
cpp/tensorrt_llm/common/assert.cpp
Executable file
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "assert.h"
|
||||
|
||||
bool CHECK_DEBUG_ENABLED = false;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
__attribute__((constructor))
|
||||
#endif
|
||||
void initOnLoad()
|
||||
{
|
||||
auto constexpr kDebugEnabled = "TRT_LLM_DEBUG_MODE";
|
||||
auto const debugEnabled = std::getenv(kDebugEnabled);
|
||||
if (debugEnabled && debugEnabled[0] == '1')
|
||||
{
|
||||
CHECK_DEBUG_ENABLED = true;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
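The new assert.cpp above gates debug-only checks behind an environment variable: a function marked __attribute__((constructor)) runs when the library is loaded and sets CHECK_DEBUG_ENABLED when TRT_LLM_DEBUG_MODE=1, so the TLLM_CHECK_DEBUG macros introduced further down cost nothing by default. A minimal standalone sketch of the same pattern follows; the names and the simplified macro are illustrative, not the library's actual definitions.

#include <cstdlib>
#include <stdexcept>

// Illustrative stand-ins for CHECK_DEBUG_ENABLED and TLLM_CHECK_DEBUG.
static bool gDebugChecksEnabled = false;

#if !defined(_MSC_VER)
__attribute__((constructor)) // runs when the binary/library is loaded, before main()
#endif
static void initDebugFlag()
{
    char const* env = std::getenv("TRT_LLM_DEBUG_MODE");
    gDebugChecksEnabled = (env != nullptr && env[0] == '1');
}

#define CHECK_DEBUG_SKETCH(cond)                                            \
    do                                                                      \
    {                                                                       \
        if (gDebugChecksEnabled && !(cond))                                 \
            throw std::runtime_error("debug check failed: " #cond);         \
    } while (0)

int main()
{
    initDebugFlag(); // explicit call keeps the sketch portable on MSVC
    int const batchSize = 8;
    CHECK_DEBUG_SKETCH(batchSize > 0); // condition is evaluated only when TRT_LLM_DEBUG_MODE=1
    return 0;
}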
@ -30,6 +30,8 @@ namespace tensorrt_llm::common
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
extern bool CHECK_DEBUG_ENABLED;
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
|
||||
#else
|
||||
@ -50,6 +52,26 @@ namespace tensorrt_llm::common
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_DEBUG(val) \
|
||||
do \
|
||||
{ \
|
||||
if (CHECK_DEBUG_ENABLED) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info) \
|
||||
do \
|
||||
{ \
|
||||
if (CHECK_DEBUG_ENABLED) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_THROW(...) \
|
||||
do \
|
||||
{ \
|
||||
|
||||
@ -390,6 +390,17 @@ void print2dToScreen(const T* result, const int r, const int c, const int stride
|
||||
print2dToStream(result, r, c, stride, stdout);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToFile(std::string fname, const T* result, const int r, const int c, const int stride)
|
||||
{
|
||||
FILE* fp = fopen(fname.c_str(), "wt");
|
||||
if (fp != nullptr)
|
||||
{
|
||||
print2dToStream(result, r, c, stride, fp);
|
||||
fclose(fp);
|
||||
}
|
||||
}
|
||||
|
||||
inline void print_float_(float x)
|
||||
{
|
||||
printf("%7.3f ", x);
|
||||
|
||||
@ -201,7 +201,7 @@ public:
|
||||
return quantMode;
|
||||
}
|
||||
|
||||
constexpr QuantMode operator+(const QuantMode& other) noexcept
|
||||
constexpr QuantMode operator+(const QuantMode& other) const noexcept
|
||||
{
|
||||
return QuantMode(mValue | other.mValue);
|
||||
}
|
||||
@ -211,7 +211,7 @@ public:
|
||||
return *this = *this + other;
|
||||
}
|
||||
|
||||
constexpr QuantMode operator-(const QuantMode& other) noexcept
|
||||
constexpr QuantMode operator-(const QuantMode& other) const noexcept
|
||||
{
|
||||
return QuantMode(mValue & ~other.mValue);
|
||||
}
|
||||
|
||||
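The QuantMode change above only adds const to operator+ and operator-: the operators never modify *this, and without the const qualifier they cannot be called on a const QuantMode, for example one returned from a getter. A reduced sketch of the same fix, using a hypothetical bitmask class rather than the real QuantMode:

#include <cstdint>

class FlagSet // simplified stand-in for tensorrt_llm::common::QuantMode
{
public:
    constexpr explicit FlagSet(uint32_t value = 0) noexcept : mValue(value) {}

    // const-qualified: callable on const objects, returns a new value instead of mutating
    constexpr FlagSet operator+(FlagSet const& other) const noexcept
    {
        return FlagSet(mValue | other.mValue);
    }

    constexpr FlagSet operator-(FlagSet const& other) const noexcept
    {
        return FlagSet(mValue & ~other.mValue);
    }

    constexpr uint32_t value() const noexcept { return mValue; }

private:
    uint32_t mValue;
};

int main()
{
    FlagSet const int8Weights{1u << 0};
    FlagSet const int8KvCache{1u << 1};
    // Both operands are const; this line only compiles because operator+ is const-qualified.
    FlagSet const combined = int8Weights + int8KvCache;
    return combined.value() == 0x3 ? 0 : 1;
}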
@ -296,6 +296,11 @@ struct TopK
|
||||
|
||||
__device__ __forceinline__ void insert(T elem, int elem_id)
|
||||
{
|
||||
if (elem_id < 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (elem > u[MAX_K - 1] || (p[MAX_K - 1] == -1) || ((elem == u[MAX_K - 1]) && (elem_id < p[MAX_K - 1])))
|
||||
// if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1])))
|
||||
{
|
||||
|
||||
@ -171,10 +171,17 @@ template <typename T>
|
||||
void invokeAddBiasApplyPenalties(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* input_lengths, const int* sequence_lengths, const T* bias, const int ite, const int local_batch_size,
|
||||
const int batch_size, const int beam_width, const int vocab_size, const int vocab_size_padded, const int* end_ids,
|
||||
const float* temperatures, const float* repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const float* temperatures, const std::vector<float>& h_temperatures, const float* repetition_penalties,
|
||||
const std::vector<float>& h_repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const int* min_lengths, const int max_seq_len, cudaStream_t stream)
|
||||
{
|
||||
if (bias != nullptr || temperatures != nullptr || vocab_size != vocab_size_padded)
|
||||
|
||||
#define ALL_OF(p_, sz_, dt_, v_) (std::all_of(p_, p_ + sz_, [&](dt_ b) { return b == v_; }))
|
||||
|
||||
if (bias != nullptr
|
||||
|| (temperatures != nullptr
|
||||
&& !ALL_OF(std::begin(h_temperatures) + ite * local_batch_size, local_batch_size, float, 1.0f))
|
||||
|| vocab_size != vocab_size_padded)
|
||||
{
|
||||
dim3 block(512);
|
||||
if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0)
|
||||
@ -199,14 +206,19 @@ void invokeAddBiasApplyPenalties(T* logits, const int** output_ids_ptr, const in
|
||||
size_t smem_size = (sizeof(T) * max_seq_len + 31) / 32 * 32 + sizeof(int) * max_seq_len;
|
||||
dim3 block(256);
|
||||
dim3 grid(beam_width * local_batch_size);
|
||||
if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative)
|
||||
float default_value = getDefaultPenaltyValue(repetition_penalty_type);
|
||||
if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative
|
||||
&& !ALL_OF(std::begin(h_repetition_penalties) + ite * local_batch_size, local_batch_size, float,
|
||||
default_value))
|
||||
{
|
||||
apply_repetition_penalty<T, false><<<grid, block, smem_size, stream>>>(logits, batch_size, beam_width,
|
||||
vocab_size, vocab_size_padded, output_ids_ptr, parent_ids_ptr, input_lengths, sequence_lengths,
|
||||
repetition_penalties, max_seq_len);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else if (repetition_penalty_type == RepetitionPenaltyType::Additive)
|
||||
else if (repetition_penalty_type == RepetitionPenaltyType::Additive
|
||||
&& !ALL_OF(std::begin(h_repetition_penalties) + ite * local_batch_size, local_batch_size, float,
|
||||
default_value))
|
||||
{
|
||||
apply_repetition_penalty<T, true><<<grid, block, smem_size, stream>>>(logits, batch_size, beam_width,
|
||||
vocab_size, vocab_size_padded, output_ids_ptr, parent_ids_ptr, input_lengths, sequence_lengths,
|
||||
@ -224,18 +236,22 @@ void invokeAddBiasApplyPenalties(T* logits, const int** output_ids_ptr, const in
|
||||
apply_min_length_penalty<<<grid_size, block_size, 0, stream>>>(
|
||||
logits, min_lengths, end_ids, sequence_lengths, input_lengths, beam_width, vocab_size_padded);
|
||||
sync_check_cuda_error();
|
||||
|
||||
#undef ALL_OF
|
||||
}
|
||||
|
||||
template void invokeAddBiasApplyPenalties(float* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* input_lengths, const int* sequence_lengths, const float* bias, const int ite, const int local_batch_size,
|
||||
const int batch_size, const int beam_width, const int vocab_size, const int vocab_size_padded, const int* end_ids,
|
||||
const float* temperatures, const float* repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const float* temperatures, const std::vector<float>& h_temperatures, const float* repetition_penalties,
|
||||
const std::vector<float>& h_repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const int* min_lengths, int max_seq_len, cudaStream_t stream);
|
||||
|
||||
template void invokeAddBiasApplyPenalties(half* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* input_lengths, const int* sequence_lengths, const half* bias, const int ite, const int local_batch_size,
|
||||
const int batch_size, const int beam_width, const int vocab_size, const int vocab_size_padded, const int* end_ids,
|
||||
const float* temperatures, const float* repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const float* temperatures, const std::vector<float>& h_temperatures, const float* repetition_penalties,
|
||||
const std::vector<float>& h_repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const int* min_lengths, int max_seq_len, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -28,7 +28,8 @@ template <typename T>
|
||||
void invokeAddBiasApplyPenalties(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* input_lengths, const int* sequence_lengths, const T* bias, const int ite, const int local_batch_size,
|
||||
const int batch_size, const int beam_width, const int vocab_size, const int vocab_size_padded, const int* end_ids,
|
||||
const float* temperatures, const float* repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const float* temperatures, const std::vector<float>& h_temperatures, const float* repetition_penalties,
|
||||
const std::vector<float>& h_repetition_penalties, const RepetitionPenaltyType repetition_penalty_type,
|
||||
const int* min_lengths, int max_seq_len, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
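The invokeAddBiasApplyPenalties changes above pass host-side copies of the temperatures and repetition penalties and use the ALL_OF macro (a std::all_of wrapper) to skip launching a penalty kernel when every value in the current batch slice equals the default. A small sketch of that host-side check, with illustrative names:

#include <algorithm>
#include <vector>

// Returns true when the slice starting at ite * localBatchSize holds only defaultValue,
// i.e. the corresponding penalty kernel launch can be skipped entirely.
bool allDefault(std::vector<float> const& hostValues, int ite, int localBatchSize, float defaultValue)
{
    auto const begin = hostValues.begin() + ite * localBatchSize;
    return std::all_of(begin, begin + localBatchSize, [&](float v) { return v == defaultValue; });
}

int main()
{
    std::vector<float> repetitionPenalties{1.0f, 1.0f, 1.2f, 1.0f};
    // First micro-batch (ite = 0, localBatchSize = 2) holds only defaults: skip the kernel.
    bool const skipFirst = allDefault(repetitionPenalties, /*ite=*/0, /*localBatchSize=*/2, /*defaultValue=*/1.0f);
    // Second micro-batch contains 1.2f: the penalty kernel must run.
    bool const skipSecond = allDefault(repetitionPenalties, /*ite=*/1, /*localBatchSize=*/2, /*defaultValue=*/1.0f);
    return (skipFirst && !skipSecond) ? 0 : 1;
}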
@ -47,16 +47,20 @@ void multihead_attention_(
|
||||
switch (params.hidden_size_per_head)
|
||||
{
|
||||
case 32: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 32>(params, kv_cache_buffer, stream); break;
|
||||
case 48: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(params, kv_cache_buffer, stream); break;
|
||||
case 64: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 64>(params, kv_cache_buffer, stream); break;
|
||||
case 128:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
case 256:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
#ifndef FAST_BUILD // skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
|
||||
case 48: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(params, kv_cache_buffer, stream); break;
|
||||
case 80: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 80>(params, kv_cache_buffer, stream); break;
|
||||
case 96: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 96>(params, kv_cache_buffer, stream); break;
|
||||
case 112:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 112>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
case 128:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
case 144:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 144>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
@ -69,9 +73,7 @@ void multihead_attention_(
|
||||
case 224:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 224>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
case 256:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(params, kv_cache_buffer, stream);
|
||||
break;
|
||||
#endif // FAST_BUILD
|
||||
default: TLLM_THROW("unsupported head_size");
|
||||
}
|
||||
}
|
||||
|
||||
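The multihead_attention_ switch above turns the runtime head size into a compile-time template parameter, and the new FAST_BUILD guard trims the instantiated sizes to the common 32/64/128/256 to shorten build times. A reduced sketch of that dispatch pattern; the launcher here is a placeholder, not the real mmha kernel:

#include <cstdio>
#include <stdexcept>

// Placeholder for mmha::mmha_launch_kernel<..., HEAD_SIZE>: the head size becomes a
// template parameter so the kernel body can be specialized and unrolled for it.
template <int HEAD_SIZE>
void launchSketch()
{
    std::printf("launching kernel specialized for head_size=%d\n", HEAD_SIZE);
}

void dispatchByHeadSize(int hiddenSizePerHead)
{
    switch (hiddenSizePerHead)
    {
    case 32: launchSketch<32>(); break;
    case 64: launchSketch<64>(); break;
    case 128: launchSketch<128>(); break;
    case 256: launchSketch<256>(); break;
#ifndef FAST_BUILD // uncommon head sizes are compiled only in full builds
    case 48: launchSketch<48>(); break;
    case 80: launchSketch<80>(); break;
    case 96: launchSketch<96>(); break;
    case 112: launchSketch<112>(); break;
#endif
    default: throw std::runtime_error("unsupported head_size");
    }
}

int main()
{
    dispatchByHeadSize(128);
    return 0;
}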
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 112;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_112 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 112;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_112 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 112;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_112 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 144;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_144 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 144;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_144 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 144;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_144 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,8 +29,10 @@ auto constexpr kSizePerHead = 160;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_160 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif // ENABLE_BF16
|
||||
#endif
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 160;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_160 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 160;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_160 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 192;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_192 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 192;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_192 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 192;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_192 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 224;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_224 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 224;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_224 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 224;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_224 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 48;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_48 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 48;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_48 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 48;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_48 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 80;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_80 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 80;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_80 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 80;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_80 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,9 +29,11 @@ auto constexpr kSizePerHead = 96;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_96 for fast build
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MMHA_LAUNCHERS(__nv_bfloat16, kSizePerHead)
|
||||
#endif
|
||||
#endif // ENABLE_BF16
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 96;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_96 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(float, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -29,7 +29,9 @@ auto constexpr kSizePerHead = 96;
|
||||
namespace mmha
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip mmha_96 for fast build
|
||||
INSTANTIATE_MMHA_LAUNCHERS(uint16_t, kSizePerHead)
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
|
||||
@ -2152,7 +2152,6 @@ __global__ void masked_multihead_attention_kernel(
|
||||
const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : tlength;
|
||||
for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK)
|
||||
{
|
||||
|
||||
const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;
|
||||
|
||||
if (!MULTI_BLOCK_FLAG)
|
||||
@ -2308,8 +2307,11 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Get the c_tile_id that handles the current timestep.
|
||||
const int ctile_idx = tlength / timesteps_per_block;
|
||||
|
||||
// One group of threads computes the product(s) for the current timestep.
|
||||
if (vo == tlength % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == gridDim.z - 1)))
|
||||
if (vo == tlength % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
|
||||
{
|
||||
const int tokenIdx = tlength;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
|
||||
@ -2396,7 +2398,6 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
#endif // MMHA_USE_FP32_ACCUM_FOR_LOGITS
|
||||
}
|
||||
|
||||
// Make sure we can start writing to shared memory.
|
||||
__syncthreads();
|
||||
|
||||
@ -2428,7 +2429,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
|
||||
const auto bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads);
|
||||
const auto bhi_seq_len_tile = bhi * params.max_seq_len_tile;
|
||||
const auto bhi_seq_len_tile = bhi * params.seq_len_tile;
|
||||
// Output the final values.
|
||||
if (vo == 0 && (Dh == Dh_MAX || vi < Dh))
|
||||
{
|
||||
@ -2499,9 +2500,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
float final_max = -FLT_MAX;
|
||||
float thread_partial_max = -FLT_MAX;
|
||||
if (tidx < gridDim.z)
|
||||
thread_partial_max = params.partial_max[bhi_seq_len_tile + tidx];
|
||||
// final_max = fmaxf(final_max, thread_partial_max);
|
||||
thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.x - 1)];
|
||||
|
||||
// Make sure we can start writing to shared memory.
|
||||
__syncthreads();
|
||||
@ -2548,34 +2547,29 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Shared memory to store partial outputs for each oi. -> size: gridDim.z * Dh * 4 Bytes. Reuse qk_smem.
|
||||
T* out_oi_smem = reinterpret_cast<T*>(smem_);
|
||||
|
||||
// Number of threads to utilize: THREADS_PER_VALUE * gridDim.z (THREADS_PER_VALUE for vectorized output
|
||||
// and gridDim.z for all the partial outputs)
|
||||
int threads_boundary = THREADS_PER_VALUE * gridDim.z; // should be smaller than THREADS_PER_BLOCK
|
||||
assert(threads_boundary <= THREADS_PER_BLOCK);
|
||||
|
||||
const auto o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
|
||||
// The partial output region this thread takes care of
|
||||
const auto oo = o_idx.x;
|
||||
// The hidden dimensions computed by this particular thread. (refer to vi)
|
||||
const auto oi = o_idx.y;
|
||||
|
||||
// Within the bound.
|
||||
const bool within_bound = oo < gridDim.z;
|
||||
|
||||
// Load partial output
|
||||
int thread_partial_out_offset = oo * params.batch_size * num_heads * params.hidden_size_per_head;
|
||||
// Load partial max (different to thread_partial_max since the threadIdx rule changes here)
|
||||
float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + oo];
|
||||
float thread_partial_max_for_out = within_bound ? params.partial_max[bhi_seq_len_tile + oo] : final_max;
|
||||
|
||||
// Load the partial outputs.
|
||||
V_vec_k thread_partial_out
|
||||
= *reinterpret_cast<const V_vec_k*>(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi]);
|
||||
|
||||
if (tidx >= threads_boundary)
|
||||
{
|
||||
zero(thread_partial_out);
|
||||
}
|
||||
V_vec_k zero_k;
|
||||
zero(zero_k);
|
||||
V_vec_k thread_partial_out = within_bound
|
||||
? *reinterpret_cast<const V_vec_k*>(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi])
|
||||
: zero_k;
|
||||
|
||||
Tk factor_compute;
|
||||
convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
|
||||
|
||||
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);
|
||||
|
||||
// Make sure we can start writing to shared memory.
|
||||
@ -2620,7 +2614,6 @@ __global__ void masked_multihead_attention_kernel(
|
||||
convert_from_float(&inv_sum_compute, inv_sum);
|
||||
|
||||
thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_partial_out);
|
||||
|
||||
*reinterpret_cast<V_vec_k*>(¶ms.out[bhi * Dh + oi]) = thread_partial_out;
|
||||
}
|
||||
|
||||
|
||||
@ -52,18 +52,22 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const bool* finished,
|
||||
switch (log_beam_width)
|
||||
{
|
||||
// 0 < beam_width <= 4
|
||||
case 0: // 1, 2
|
||||
case 1: // 3, 4
|
||||
case 0: // 1, 2
|
||||
case 1: // 3, 4
|
||||
CASE_K(4)
|
||||
case 2: // 4 < beam_width <= 8
|
||||
case 2: // 4 < beam_width <= 8
|
||||
CASE_K(8)
|
||||
case 3: // 9 < beam_width <= 16
|
||||
#ifndef FAST_BUILD // For fast build, skip case 3, 4, 5
|
||||
case 3: // 9 < beam_width <= 16
|
||||
CASE_K(16)
|
||||
case 4: // 16 < beam_width <= 32
|
||||
case 4: // 16 < beam_width <= 32
|
||||
CASE_K(32)
|
||||
case 5: // 32 < beam_width <= 64
|
||||
case 5: // 32 < beam_width <= 64
|
||||
CASE_K(64)
|
||||
default: throw std::runtime_error(fmtstr("Topk kernel of beam search does not support beam_width=%d", beam_width));
|
||||
#endif // FAST_BUILD
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
fmtstr("%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, beam_width));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
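In the invokeTopkSoftMax switch above, log_beam_width selects a padded beam width: widths up to 4 map to the K=4 kernel, 5 to 8 map to K=8, and under FAST_BUILD the 16/32/64 instantiations are compiled out so larger widths fall through to the runtime error. A sketch of the bucket selection; the arithmetic is illustrative, not the library's exact computation:

#include <stdexcept>

// Smallest supported template K that can hold beamWidth; mirrors CASE_K(4/8/16/32/64).
int paddedBeamWidth(int beamWidth, bool fastBuild)
{
    int const maxSupported = fastBuild ? 8 : 64; // FAST_BUILD keeps only the K=4 and K=8 kernels
    for (int k = 4; k <= maxSupported; k *= 2)
    {
        if (beamWidth <= k)
            return k;
    }
    throw std::runtime_error("Topk kernel of beam search does not support this beam_width");
}

int main()
{
    return (paddedBeamWidth(3, false) == 4 && paddedBeamWidth(9, false) == 16) ? 0 : 1;
}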
@ -20,9 +20,9 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip beam_width between [?, 16] for fast build
|
||||
INSTANTIATE_BEAMSEARCH_K(float, 16);
|
||||
INSTANTIATE_BEAMSEARCH_K(half, 16);
|
||||
|
||||
#endif // FAST_BUILD
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -21,8 +21,10 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip beam_width between [?, 32] for fast build
|
||||
INSTANTIATE_BEAMSEARCH_K(float, 32);
|
||||
INSTANTIATE_BEAMSEARCH_K(half, 32);
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -21,8 +21,10 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
#ifndef FAST_BUILD // skip beam_width between [?, 64] for fast build
|
||||
INSTANTIATE_BEAMSEARCH_K(float, 64);
|
||||
INSTANTIATE_BEAMSEARCH_K(half, 64);
|
||||
#endif // FAST_BUILD
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -181,16 +181,11 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
{
|
||||
if (beam_hyps.num_beams != nullptr && x[total.p[i]] % vocab_size == beam_hyps.end_ids[vector_id])
|
||||
if (i < K && beam_hyps.num_beams != nullptr && x[total.p[i]] % vocab_size == beam_hyps.end_ids[vector_id])
|
||||
{
|
||||
// if beam_token does not belong to top num_beams tokens, it should not
|
||||
// be added. Refer from
|
||||
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
|
||||
if (i >= K)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else
|
||||
{
|
||||
const float normed_score = (float) total.u[i];
|
||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
|
||||
@ -274,7 +274,11 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
randNum = randNum - expLogit;
|
||||
if (randNum <= 0.0f || i == k - 1)
|
||||
{
|
||||
ids[batchId][sequenceLengths[batchId]] = topKTmpIdBuf[batchId * stride + s_id[i]] % vocabSize;
|
||||
int idx = s_id[i];
|
||||
// If s_id is -1 here we force output token to the last from vocabulary to get vivid indicator of smth
|
||||
// going wrong for the debug
|
||||
auto outputId = idx != -1 ? topKTmpIdBuf[batchId * stride + idx] % vocabSize : vocabSize - 1;
|
||||
ids[batchId][sequenceLengths[batchId]] = outputId;
|
||||
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
|
||||
{
|
||||
float logProb = logf(expLogit);
|
||||
|
||||
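The topKStage2Sampling change above guards against a -1 index from the top-K buffer (which can occur when the logits are NaN): instead of reading out of bounds, the kernel emits the last vocabulary id as a deliberately conspicuous marker for debugging. The same defensive pattern in isolation, with illustrative names:

#include <vector>

// Pick the sampled token id, falling back to the last vocabulary id when the
// selected index is the -1 sentinel (e.g. because the logits were NaN).
int selectOutputId(std::vector<int> const& topKTmpIdBuf, int idx, int vocabSize)
{
    return idx != -1 ? topKTmpIdBuf[idx] % vocabSize : vocabSize - 1;
}

int main()
{
    std::vector<int> topKTmpIdBuf{12, 7, 42};
    int const vocabSize = 32000;
    int const ok = selectOutputId(topKTmpIdBuf, 1, vocabSize);        // -> 7
    int const fallback = selectOutputId(topKTmpIdBuf, -1, vocabSize); // -> 31999, easy to spot downstream
    return (ok == 7 && fallback == vocabSize - 1) ? 0 : 1;
}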
@ -27,6 +27,7 @@ namespace kernels
|
||||
//! Computes sequenceLength, finished state, cumLogProbs inplace.
|
||||
//! Sampling per request can be controlled using skipDecode, topPs and topKs parameters.
|
||||
//! Function sets workspaceSize and exits early if workspace is nullptr.
|
||||
//! If logits are Nan, we set output token to be the last in the vocabulary.
|
||||
//!
|
||||
//! \param workspace pointer to the workspace. Has to be pre-allocated by caller. Function does not take ownership of the
|
||||
//! buffer.
|
||||
|
||||
@ -190,7 +190,8 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
|
||||
invokeAddBiasApplyPenalties(logits.getPtr<T>(), output_ids_ptr.template getPtr<const int*>(),
|
||||
outputs.parent_ids_ptr.template getPtr<const int*>(), input_lengths, sequence_length, embedding_bias, ite,
|
||||
local_batch_size, batch_size, beam_width, vocab_size_, vocab_size_padded_, end_ids, temperature_buf_,
|
||||
repetition_penalty_buf_, mRepetitionPenaltyType, min_lengths_buf_, max_seq_len, stream_);
|
||||
mTemperature, repetition_penalty_buf_, mRepetitionPenalty, mRepetitionPenaltyType, min_lengths_buf_,
|
||||
max_seq_len, stream_);
|
||||
sync_check_cuda_error();
|
||||
|
||||
invokeSoftMax(outputs, params);
|
||||
|
||||
41
cpp/tensorrt_llm/pybind/CMakeLists.txt
Normal file
@ -0,0 +1,41 @@
|
||||
set(TRTLLM_PYBIND_MODULE bindings)
|
||||
set(TRTLLM_PYBIND_MODULE
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PARENT_SCOPE)
|
||||
|
||||
if(NOT BUILD_PYT)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Python bindings for C++ runtime require PyTorch. Please enable BUILD_PYT"
|
||||
)
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} "-c"
|
||||
"import pybind11 as pb11; print(pb11.get_cmake_dir(),end='');"
|
||||
RESULT_VARIABLE PYBIND_CMAKE_DIR_RET
|
||||
OUTPUT_VARIABLE PYBIND_CMAKE_DIR)
|
||||
|
||||
if(PYBIND_CMAKE_DIR_RET MATCHES 0)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${PYBIND_CMAKE_DIR}")
|
||||
else()
|
||||
message(ERROR "pybind11 CMake directory not found.")
|
||||
endif()
|
||||
|
||||
find_package(pybind11 REQUIRED)
|
||||
|
||||
set(SRCS bindings.cpp runtime/generationInput.cpp runtime/generationOutput.cpp)
|
||||
|
||||
pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})
|
||||
|
||||
set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
|
||||
ON)
|
||||
|
||||
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
|
||||
"${TORCH_INSTALL_PREFIX}/lib")
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC ${STATIC_TARGET} ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python
|
||||
${UNDEFINED_FLAG})
|
||||
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE})
|
||||
250
cpp/tensorrt_llm/pybind/bindings.cpp
Normal file
@ -0,0 +1,250 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <pybind11/operators.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "runtime/generationInput.h"
|
||||
#include "runtime/generationOutput.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
#include "tensorrt_llm/runtime/gptSession.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
namespace tpr = tensorrt_llm::pybind::runtime;
|
||||
|
||||
#if not defined(TRTLLM_PYBIND_MODULE)
|
||||
#error "TRTLLM_PYBIND_MODULE must be defined"
|
||||
#endif
|
||||
|
||||
PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
{
|
||||
m.doc() = "TensorRT-LLM Python bindings for C++ runtime";
|
||||
|
||||
py::class_<tpr::PromptTuningParams>(m, "PromptTuningParams")
|
||||
.def(py::init<tpr::PromptTuningParams::TensorPtr, tpr::PromptTuningParams::TensorPtr,
|
||||
tpr::PromptTuningParams::TensorPtr>(),
|
||||
py::arg("embedding_table") = py::none(), py::arg("tasks") = py::none(), py::arg("vocab_size") = py::none())
|
||||
.def_readwrite("embedding_table", &tpr::PromptTuningParams::embeddingTable)
|
||||
.def_readwrite("tasks", &tpr::PromptTuningParams::tasks)
|
||||
.def_readwrite("vocab_size", &tpr::PromptTuningParams::vocabSize)
|
||||
.def_readwrite("prompt_tuning_enabled", &tpr::PromptTuningParams::promptTuningEnabled);
|
||||
|
||||
py::class_<tpr::GenerationInput>(m, "GenerationInput")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr,
|
||||
bool>(),
|
||||
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
|
||||
.def_readwrite("end_id", &tpr::GenerationInput::endId)
|
||||
.def_readwrite("pad_id", &tpr::GenerationInput::padId)
|
||||
.def_readwrite("ids", &tpr::GenerationInput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationInput::lengths)
|
||||
.def_readwrite("packed", &tpr::GenerationInput::packed)
|
||||
.def_readwrite("embedding_bias", &tpr::GenerationInput::embeddingBiasOpt)
|
||||
.def_readwrite("bad_words_list", &tpr::GenerationInput::badWordsList)
|
||||
.def_readwrite("stop_words_list", &tpr::GenerationInput::stopWordsList)
|
||||
.def_readwrite("max_new_tokens", &tpr::GenerationInput::maxNewTokens)
|
||||
.def_readwrite("prompt_tuning_params", &tpr::GenerationInput::promptTuningParams);
|
||||
|
||||
py::class_<tpr::GenerationOutput>(m, "GenerationOutput")
|
||||
.def(py::init<tpr::GenerationOutput::TensorPtr, tpr::GenerationOutput::TensorPtr>(), py::arg("ids"),
|
||||
py::arg("lengths"))
|
||||
.def_readwrite("ids", &tpr::GenerationOutput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationOutput::lengths)
|
||||
.def_readwrite("log_probs", &tpr::GenerationOutput::logProbs)
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits);
|
||||
|
||||
py::class_<tb::kv_cache_manager::KvCacheConfig>(m, "KvCacheConfig")
|
||||
.def(py::init<std::optional<tr::SizeType>, std::optional<float>>(), py::arg("max_tokens") = py::none(),
|
||||
py::arg("free_gpu_memory_fraction") = py::none())
|
||||
.def_readwrite("max_tokens", &tb::kv_cache_manager::KvCacheConfig::maxTokens)
|
||||
.def_readwrite("free_gpu_memory_fraction", &tb::kv_cache_manager::KvCacheConfig::freeGpuMemoryFraction);
|
||||
|
||||
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
|
||||
py::arg("max_sequence_length"))
|
||||
.def_readwrite("max_batch_size", &tr::GptSession::Config::maxBatchSize)
|
||||
.def_readwrite("max_beam_width", &tr::GptSession::Config::maxBeamWidth)
|
||||
.def_readwrite("max_sequence_length", &tr::GptSession::Config::maxSequenceLength)
|
||||
.def_readwrite("decoder_per_request", &tr::GptSession::Config::decoderPerRequest)
|
||||
.def_readwrite("cuda_graph_mode", &tr::GptSession::Config::cudaGraphMode)
|
||||
.def_readwrite("ctx_micro_batch_size", &tr::GptSession::Config::ctxMicroBatchSize)
|
||||
.def_readwrite("gen_micro_batch_size", &tr::GptSession::Config::genMicroBatchSize)
|
||||
.def_readwrite("kv_cache_config", &tr::GptSession::Config::kvCacheConfig);
|
||||
|
||||
py::enum_<nvinfer1::DataType>(m, "DataType")
|
||||
.value("FLOAT", nvinfer1::DataType::kFLOAT)
|
||||
.value("HALF", nvinfer1::DataType::kHALF)
|
||||
.value("INT8", nvinfer1::DataType::kINT8)
|
||||
.value("INT32", nvinfer1::DataType::kINT32)
|
||||
.value("BOOL", nvinfer1::DataType::kBOOL)
|
||||
.value("UINT8", nvinfer1::DataType::kUINT8)
|
||||
.value("FP8", nvinfer1::DataType::kFP8)
|
||||
.value("BF16", nvinfer1::DataType::kBF16)
|
||||
.value("INT64", nvinfer1::DataType::kINT64)
|
||||
.export_values();
|
||||
|
||||
py::enum_<tr::GptModelConfig::ModelVariant>(m, "GptModelVariant")
|
||||
.value("GPT", tr::GptModelConfig::ModelVariant::kGpt)
|
||||
.value("GLM", tr::GptModelConfig::ModelVariant::kGlm);
|
||||
|
||||
py::class_<tc::QuantMode>(m, "QuantMode")
|
||||
.def_static("none", &tc::QuantMode::none)
|
||||
.def_static("int4_weights", &tc::QuantMode::int4Weights)
|
||||
.def_static("int8_weights", &tc::QuantMode::int8Weights)
|
||||
.def_static("activations", &tc::QuantMode::activations)
|
||||
.def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling)
|
||||
.def_static("per_token_scaling", &tc::QuantMode::perTokenScaling)
|
||||
.def_static("per_group_scaling", &tc::QuantMode::perGroupScaling)
|
||||
.def_static("int8_kv_cache", &tc::QuantMode::int8KvCache)
|
||||
.def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache)
|
||||
.def_static("fp8_qdq", &tc::QuantMode::fp8Qdq)
|
||||
.def_property_readonly("value", &tc::QuantMode::value)
|
||||
.def("is_set", &tc::QuantMode::isSet, py::arg("mode"))
|
||||
.def_property_readonly("has_int4_weights", &tc::QuantMode::hasInt4Weights)
|
||||
.def_property_readonly("has_int8_weights", &tc::QuantMode::hasInt8Weights)
|
||||
.def_property_readonly("has_activations", &tc::QuantMode::hasActivations)
|
||||
.def_property_readonly("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling)
|
||||
.def_property_readonly("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling)
|
||||
.def_property_readonly("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling)
|
||||
.def_property_readonly("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling)
|
||||
.def_property_readonly("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache)
|
||||
.def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache)
|
||||
.def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq)
|
||||
.def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant)
|
||||
.def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights") = false,
|
||||
py::arg("quantize_activations") = false, py::arg("per_token") = false, py::arg("per_channel") = false,
|
||||
py::arg("use_int4_weights") = false, py::arg("use_int8_kv_cache") = false,
|
||||
py::arg("use_fp8_kv_kache") = false, py::arg("use_fp8_qdq") = false)
|
||||
.def(py::self + py::self)
|
||||
.def(py::self += py::self)
|
||||
.def(py::self - py::self)
|
||||
.def(py::self -= py::self)
|
||||
.def(py::self == py::self)
|
||||
.def(py::self != py::self);
|
||||
|
||||
py::class_<tr::GptModelConfig>(m, "GptModelConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType, nvinfer1::DataType>(),
|
||||
py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_heads"), py::arg("hidden_size"),
|
||||
py::arg("data_type"))
|
||||
.def_property_readonly("vocab_size", &tr::GptModelConfig::getVocabSize)
|
||||
.def("vocab_size_padded", &tr::GptModelConfig::getVocabSizePadded, py::arg("world_size"))
|
||||
.def("num_layers", &tr::GptModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
|
||||
.def_property_readonly("num_heads", &tr::GptModelConfig::getNbHeads)
|
||||
.def_property_readonly("hidden_size", &tr::GptModelConfig::getHiddenSize)
|
||||
.def_property_readonly("size_per_head", &tr::GptModelConfig::getSizePerHead)
|
||||
.def_property_readonly("data_type", &tr::GptModelConfig::getDataType)
|
||||
.def_property("num_kv_heads", &tr::GptModelConfig::getNbKvHeads, &tr::GptModelConfig::setNbKvHeads)
|
||||
.def_property("use_gpt_attention_plugin",
|
||||
py::overload_cast<>(&tr::GptModelConfig::useGptAttentionPlugin, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::useGptAttentionPlugin))
|
||||
.def_property("use_packed_input", py::overload_cast<>(&tr::GptModelConfig::usePackedInput, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::usePackedInput))
|
||||
.def_property("use_paged_kv_cache", py::overload_cast<>(&tr::GptModelConfig::usePagedKvCache, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::usePagedKvCache))
|
||||
.def_property(
|
||||
"tokens_per_block", &tr::GptModelConfig::getTokensPerBlock, &tr::GptModelConfig::setTokensPerBlock)
|
||||
.def_property("quant_mode", &tr::GptModelConfig::getQuantMode, &tr::GptModelConfig::setQuantMode)
|
||||
.def_property_readonly("supports_inflight_batching", &tr::GptModelConfig::supportsInflightBatching)
|
||||
.def_property("max_batch_size", &tr::GptModelConfig::getMaxBatchSize, &tr::GptModelConfig::setMaxBatchSize)
|
||||
.def_property("max_input_len", &tr::GptModelConfig::getMaxInputLen, &tr::GptModelConfig::setMaxInputLen)
|
||||
.def_property("max_output_len", &tr::GptModelConfig::getMaxOutputLen, &tr::GptModelConfig::setMaxOutputLen)
|
||||
.def_property("max_num_tokens", &tr::GptModelConfig::getMaxNumTokens, &tr::GptModelConfig::setMaxNumTokens)
|
||||
.def_property("compute_context_logits",
|
||||
py::overload_cast<>(&tr::GptModelConfig::computeContextLogits, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::computeContextLogits))
|
||||
.def_property("model_variant", &tr::GptModelConfig::getModelVariant, &tr::GptModelConfig::setModelVariant)
|
||||
.def_property("use_custom_all_reduce", py::overload_cast<>(&tr::GptModelConfig::useCustomAllReduce, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::useCustomAllReduce));
|
||||
|
||||
py::class_<tr::WorldConfig>(m, "WorldConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("tensor_parallelism") = 1,
|
||||
py::arg("pipeline_parallelism") = 1, py::arg("rank") = 0,
|
||||
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode)
|
||||
.def_property_readonly("size", &tr::WorldConfig::getSize)
|
||||
.def_property_readonly("tensor_parallelism", &tr::WorldConfig::getTensorParallelism)
|
||||
.def_property_readonly("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism)
|
||||
.def_property_readonly("is_tensor_parallel", &tr::WorldConfig::isTensorParallel)
|
||||
.def_property_readonly("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel)
|
||||
.def_property_readonly("rank", &tr::WorldConfig::getRank)
|
||||
.def_property_readonly("gpus_per_node", &tr::WorldConfig::getGpusPerNode)
|
||||
.def_property_readonly("device", &tr::WorldConfig::getDevice)
|
||||
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
|
||||
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
|
||||
.def_static("mpi",
|
||||
py::overload_cast<tr::SizeType, std::optional<tr::SizeType>, std::optional<tr::SizeType>>(
|
||||
&tr::WorldConfig::mpi),
|
||||
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
|
||||
py::arg("pipeline_parallelism") = py::none());
|
||||
|
||||
py::class_<tr::SamplingConfig>(m, "SamplingConfig")
|
||||
.def(py::init<tr::SizeType>(), py::arg("beam_width") = 1)
|
||||
.def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
|
||||
.def_readwrite("temperature", &tr::SamplingConfig::temperature)
|
||||
.def_readwrite("min_length", &tr::SamplingConfig::minLength)
|
||||
.def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
|
||||
.def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
|
||||
.def_readwrite("top_k", &tr::SamplingConfig::topK)
|
||||
.def_readwrite("top_p", &tr::SamplingConfig::topP)
|
||||
.def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)
|
||||
.def_readwrite("top_p_decay", &tr::SamplingConfig::topPDecay)
|
||||
.def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
|
||||
.def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
|
||||
.def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
|
||||
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
|
||||
|
||||
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
|
||||
.def(py::init<std::string, std::string, tr::SizeType, tr::SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
|
||||
py::arg("model_config"))
|
||||
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
|
||||
.def_static(
|
||||
"parse_file", [](std::string const& file) { return tr::GptJsonConfig::parse(std::filesystem::path(file)); },
|
||||
py::arg("file"))
|
||||
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
|
||||
.def_property_readonly("name", &tr::GptJsonConfig::getName)
|
||||
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
|
||||
.def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
|
||||
.def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
|
||||
.def_property_readonly("world_size", &tr::GptJsonConfig::getWorldSize)
|
||||
.def("engine_filename",
|
||||
py::overload_cast<const tr::WorldConfig&, const std::string&>(
|
||||
&tr::GptJsonConfig::engineFilename, py::const_),
|
||||
py::arg("world_config"), py::arg("model"))
|
||||
.def("engine_filename",
|
||||
py::overload_cast<const tr::WorldConfig&>(&tr::GptJsonConfig::engineFilename, py::const_),
|
||||
py::arg("world_config"));
|
||||
|
||||
py::class_<tr::GptSession>(m, "GptSession")
|
||||
.def(py::init<tr::GptSession::Config, tr::GptModelConfig, tr::WorldConfig, std::string>(), py::arg("config"),
|
||||
py::arg("model_config"), py::arg("world_config"), py::arg("engine_file"))
|
||||
.def_property_readonly("model_config", &tr::GptSession::getModelConfig)
|
||||
.def_property_readonly("world_config", &tr::GptSession::getWorldConfig)
|
||||
.def_property_readonly("device", &tr::GptSession::getDevice)
|
||||
.def(
|
||||
"generate",
|
||||
[](tr::GptSession& self, tpr::GenerationOutput& outputs, tpr::GenerationInput const& inputs,
|
||||
tr::SamplingConfig const& samplingConfig)
|
||||
{ self.generate(*outputs.toTrtLlm(), *inputs.toTrtLlm(), samplingConfig); },
|
||||
py::arg("outputs"), py::arg("inputs"), py::arg("sampling_config"));
|
||||
}
|
||||
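bindings.cpp above exposes the C++ runtime to Python via pybind11: py::class_ with def_readwrite mirrors plain members, def_property/def_property_readonly wraps getter/setter pairs, and py::arg supplies keyword names and defaults. A minimal, self-contained module in the same style; the names here are hypothetical, not the real TensorRT-LLM bindings:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// A small config struct standing in for something like GptSessionConfig.
struct SessionConfigSketch
{
    SessionConfigSketch(int maxBatchSize_, int maxBeamWidth_)
        : maxBatchSize(maxBatchSize_)
        , maxBeamWidth(maxBeamWidth_)
    {
    }

    int maxBatchSize;
    int maxBeamWidth;
    bool cudaGraphMode{false};
};

PYBIND11_MODULE(bindings_sketch, m)
{
    m.doc() = "Illustrative pybind11 module following the pattern used in bindings.cpp";

    py::class_<SessionConfigSketch>(m, "SessionConfigSketch")
        .def(py::init<int, int>(), py::arg("max_batch_size"), py::arg("max_beam_width"))
        .def_readwrite("max_batch_size", &SessionConfigSketch::maxBatchSize)
        .def_readwrite("max_beam_width", &SessionConfigSketch::maxBeamWidth)
        .def_readwrite("cuda_graph_mode", &SessionConfigSketch::cudaGraphMode);
}

Built against pybind11, this would import as a module named bindings_sketch and expose keyword-argument construction just like the GptSessionConfig binding above.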
54
cpp/tensorrt_llm/pybind/runtime/generationInput.cpp
Normal file
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "generationInput.h"
|
||||
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
using namespace tensorrt_llm::pybind::runtime;
|
||||
|
||||
std::shared_ptr<tr::PromptTuningParams> PromptTuningParams::toTrtLlm() const
|
||||
{
|
||||
auto ptt = std::make_shared<tr::PromptTuningParams>();
|
||||
if (embeddingTable)
|
||||
ptt->embeddingTable = tr::TorchView::of(embeddingTable.value());
|
||||
if (tasks)
|
||||
ptt->tasks = tr::TorchView::of(tasks.value());
|
||||
if (vocabSize)
|
||||
ptt->vocabSize = tr::TorchView::of(vocabSize.value());
|
||||
ptt->promptTuningEnabled = promptTuningEnabled;
|
||||
return ptt;
|
||||
}
|
||||
|
||||
std::shared_ptr<tr::GenerationInput> GenerationInput::toTrtLlm() const
|
||||
{
|
||||
auto input = std::make_shared<tr::GenerationInput>(
|
||||
endId, padId, tr::TorchView::of(ids.value()), tr::TorchView::of(lengths.value()), packed);
|
||||
if (embeddingBiasOpt)
|
||||
input->embeddingBiasOpt = tr::TorchView::of(embeddingBiasOpt.value());
|
||||
if (badWordsList)
|
||||
input->badWordsList = tr::TorchView::of(badWordsList.value());
|
||||
if (stopWordsList)
|
||||
input->stopWordsList = tr::TorchView::of(stopWordsList.value());
|
||||
input->maxNewTokens = maxNewTokens;
|
||||
input->promptTuningParams = *promptTuningParams.toTrtLlm();
|
||||
return input;
|
||||
}
|
||||
66
cpp/tensorrt_llm/pybind/runtime/generationInput.h
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::pybind::runtime
|
||||
{
|
||||
|
||||
using SizeType = tensorrt_llm::runtime::SizeType;
|
||||
|
||||
class PromptTuningParams : public tensorrt_llm::runtime::GenericPromptTuningParams<std::optional<at::Tensor>>
|
||||
{
|
||||
public:
|
||||
using Base = tensorrt_llm::runtime::GenericPromptTuningParams<std::optional<at::Tensor>>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
using SizeType = Base::SizeType;
|
||||
|
||||
explicit PromptTuningParams(
|
||||
TensorPtr embeddingTable = TensorPtr(), TensorPtr tasks = TensorPtr(), TensorPtr vocabSize = TensorPtr())
|
||||
: GenericPromptTuningParams(std::move(embeddingTable), std::move(tasks), std::move(vocabSize))
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::PromptTuningParams> toTrtLlm() const;
|
||||
};
|
||||
|
||||
class GenerationInput
|
||||
: public tensorrt_llm::runtime::GenericGenerationInput<std::optional<at::Tensor>, PromptTuningParams>
|
||||
{
|
||||
public:
|
||||
using Base = tensorrt_llm::runtime::GenericGenerationInput<std::optional<at::Tensor>, PromptTuningParams>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
|
||||
explicit GenerationInput(
|
||||
SizeType const endId, SizeType const padId, TensorPtr ids, TensorPtr lengths, bool packed = false)
|
||||
: GenericGenerationInput(endId, padId, std::move(ids), std::move(lengths), packed)
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationInput> toTrtLlm() const;
|
||||
};
|
||||
} // namespace tensorrt_llm::pybind::runtime
|
||||
39
cpp/tensorrt_llm/pybind/runtime/generationOutput.cpp
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "generationOutput.h"
|
||||
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
using namespace tensorrt_llm::pybind::runtime;
|
||||
|
||||
std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
|
||||
{
|
||||
auto output
|
||||
= std::make_shared<tr::GenerationOutput>(tr::TorchView::of(ids.value()), tr::TorchView::of(lengths.value()));
|
||||
if (logProbs)
|
||||
{
|
||||
output->logProbs = tr::TorchView::of(logProbs.value());
|
||||
}
|
||||
if (contextLogits)
|
||||
{
|
||||
output->contextLogits = tr::TorchView::of(contextLogits.value());
|
||||
}
|
||||
// TODO(mseznec): add support for onTokenGenerated
|
||||
return output;
|
||||
}
|
||||
41
cpp/tensorrt_llm/pybind/runtime/generationOutput.h
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::pybind::runtime
|
||||
{
|
||||
|
||||
class GenerationOutput : public tensorrt_llm::runtime::GenericGenerationOutput<std::optional<at::Tensor>>
|
||||
{
|
||||
public:
|
||||
using Base = tensorrt_llm::runtime::GenericGenerationOutput<std::optional<at::Tensor>>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
|
||||
explicit GenerationOutput(TensorPtr ids, TensorPtr lengths)
|
||||
: GenericGenerationOutput(std::move(ids), std::move(lengths))
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] std::shared_ptr<tensorrt_llm::runtime::GenerationOutput> toTrtLlm() const;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::pybind::runtime
|
||||
@@ -17,6 +17,7 @@ include(FetchContent)
set(SRCS
  utils/numpyUtils.cpp
  utils/sessionUtils.cpp
  utils/debugUtils.cu
  bufferManager.cpp
  decodingOutput.cpp
  gptDecoder.cpp
@@ -28,6 +29,7 @@ set(SRCS
  ipcUtils.cpp
  memoryCounters.cpp
  ncclCommunicator.cpp
  promptTuningParams.cpp
  runtimeBuffers.cpp
  runtimeKernels.cu
  statefulGptDecoder.cpp
@@ -212,9 +212,28 @@ void GptDecoderBatch::newRequest(
    TensorPtr endIdTensorPtr{ITensor::slice(constPointerCast(dJointInput.endIds), batchIdx, localBatchSize)};
    kernels::invokeFill(*endIdTensorPtr, endId, *stream);
    dInput = std::make_unique<DecodingInput>(inputLength, localBatchSize, dJointInput.logits, endIdTensorPtr);
    dInput->embeddingBias = request.embeddingBias;
    dInput->badWordsList = request.badWordsList;
    dInput->stopWordsList = request.stopWordsList;

    // Here, we need to add leading 1 dimension since decoderInput expects batchSize as leading dim
    // and decoder_batch::Request doesn't have batch dimension
    if (request.embeddingBias)
    {
        TensorPtr biasView = ITensor::view(request.embeddingBias);
        biasView->unsqueeze(0);
        dInput->embeddingBias = biasView;
    }
    if (request.badWordsList)
    {
        TensorPtr badWordsView = ITensor::view(request.badWordsList);
        badWordsView->unsqueeze(0);
        dInput->badWordsList = badWordsView;
    }
    if (request.stopWordsList)
    {
        TensorPtr stopWordsView = ITensor::view(request.stopWordsList);
        stopWordsView->unsqueeze(0);
        dInput->stopWordsList = stopWordsView;
    }

    TensorPtr sequenceLimitLength{
        ITensor::slice(constPointerCast(dJointInput.sequenceLimitLength), batchIdx, localBatchSize)};
    kernels::invokeFill(*sequenceLimitLength, inputLength + maxNewTokens, *stream);
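The three blocks above all follow the same pattern: wrap the per-request tensor in a fresh view and prepend a batch dimension of one. A small self-contained sketch of the shape effect, using only calls that appear in this diff; the header paths and concrete sizes are assumptions.

```cpp
#include "tensorrt_llm/runtime/bufferManager.h" // assumed header locations
#include "tensorrt_llm/runtime/iTensor.h"

using namespace tensorrt_llm::runtime;

void unsqueezeSketch(BufferManager& manager)
{
    // One request's bad-words list, shape [2, numBadWords] (no batch dimension).
    ITensor::SharedPtr badWords = manager.gpu(ITensor::makeShape({2, 8}), nvinfer1::DataType::kINT32);

    // view() shares storage but owns its own shape, so the original tensor is untouched.
    ITensor::SharedPtr badWordsView = ITensor::view(badWords);
    badWordsView->unsqueeze(0); // now [1, 2, 8]: a batch of one, as the decoder expects
}
```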
@@ -437,10 +456,20 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con
            inputView = ITensor::slice(inputs.ids, batchIdx, 1);
            inputView->reshape(inputShape);
        }
        auto request = decoder_batch::Request{inputView, std::nullopt, inputs.endId, inputs.padId};
        request.embeddingBias = inputs.embeddingBiasOpt;
        request.badWordsList = inputs.badWordsList;
        request.stopWordsList = inputs.stopWordsList;
        auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId, inputs.padId};

        if (inputs.embeddingBiasOpt)
        {
            TLLM_THROW("newBatch doesn't support embeddingBias yet.");
        }
        if (inputs.badWordsList)
        {
            TLLM_THROW("newBatch doesn't support badWordsList yet.");
        }
        if (inputs.stopWordsList)
        {
            TLLM_THROW("newBatch doesn't support stopWordsList yet.");
        }
        newRequest(batchIdx, request, extractSamplingConfig(samplingConfig, batchIdx));
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
@@ -38,9 +38,10 @@ FieldType parseJsonFieldOr(Json const& json, std::string_view name, FieldType de
    {
        value = json.at(name).template get<FieldType>();
    }
    catch (nlohmann::json::out_of_range&)
    catch (nlohmann::json::out_of_range& e)
    {
        // std::cerr << e.what() << '\n';
        TLLM_LOG_WARNING("Parameter %s cannot be read from json:", std::string(name).c_str());
        TLLM_LOG_WARNING(e.what());
    }
    return value;
}
@@ -102,6 +103,8 @@ GptJsonConfig parseJson(InputType&& i)
    auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
    auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
    auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
    auto const maxPromptEmbeddingTableSize
        = parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);

    auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);

@@ -127,11 +130,12 @@ GptJsonConfig parseJson(InputType&& i)
    modelConfig.setMaxInputLen(maxInputLen);
    modelConfig.setMaxOutputLen(maxOutputLen);
    modelConfig.setMaxNumTokens(maxNumTokens);
    modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);

    if (name == std::string("chatglm-6b"))
    {
        modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
        // kGlm is only for ChatGLM-6B, not for ChatGLM2-6B
        // kGlm is only for ChatGLM-6B and Glm-10B
    }

    return GptJsonConfig{name, precision, tensorParallelism, pipelineParallelism, modelConfig};
@@ -21,6 +21,7 @@

#include "iBuffer.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/gptDecoderBatch.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
@@ -48,7 +49,6 @@ GptSession::GptSession(Config const& sessionConfig, GptModelConfig const& modelC
    , mDevice{utils::initDevice(worldConfig)}
    , mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
    , mRuntime{std::make_shared<TllmRuntime>(engineBuffer, engineSize, *mLogger)}
    , mNumMicroBatches{worldConfig.getPipelineParallelism()}
    , mDecoders{}
    , mBuffers{}
    , mCudaGraphInstances{}
@@ -59,6 +59,9 @@ GptSession::GptSession(Config const& sessionConfig, GptModelConfig const& modelC
        mCommStream = std::make_shared<CudaStream>();
    }

    TLLM_CHECK_WITH_INFO(!(mModelConfig.usePromptTuning() && !mModelConfig.useGptAttentionPlugin()),
        "Prompt tuning is only enabled with GPT attention plugin.");

    // TODO compare expected and runtime tensor names?

    setup(sessionConfig);
@@ -74,7 +77,7 @@ BufferManager& GptSession::getBufferManager() const
    return mRuntime->getBufferManager();
}

void GptSession::createContexts(SizeType numMicroBatches, bool useCudaGraphs)
void GptSession::createContexts(SizeType numCtxBatches, SizeType numGenBatches, bool useCudaGraphs)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    mRuntime->clearContexts();
@@ -82,31 +85,22 @@ void GptSession::createContexts(SizeType numMicroBatches, bool useCudaGraphs)
    if (useCudaGraphs)
    {
        // Instantiate multiple graph instances for flip-flopping
        mCudaGraphInstances.resize(2 * numMicroBatches);
        mCudaGraphInstances.resize(2 * numGenBatches);
    }

    auto const numProfiles = mRuntime->getNbProfiles();
    TLLM_CHECK_WITH_INFO(
        numProfiles == 1 || numProfiles == 2, "GPT only expects one optimization profile or two optimization profiles");

    if (numProfiles == 2)
    {
        auto constexpr ctxContextId = 0;
        auto constexpr genContextId = 1;
        // Instantiate 2 contexts for flip-flopping
        for (auto i = 0; i < 2 * numMicroBatches; ++i)
            mRuntime->addContext(genContextId);
        // Instantiate 1 context for context phase
        for (auto i = 0; i < numMicroBatches; ++i)
            mRuntime->addContext(ctxContextId);
    }
    else
    {
        auto constexpr contextId = 0;
        // Instantiate 2 contexts for flip-flopping
        for (auto i = 0; i < 2 * numMicroBatches; ++i)
            mRuntime->addContext(contextId);
    }
    auto constexpr ctxContextId = 0;
    auto const genContextId = static_cast<std::int32_t>(numProfiles == 2);
    // Instantiate 2 contexts for flip-flopping
    for (auto i = 0; i < 2 * numGenBatches; ++i)
        mRuntime->addContext(genContextId);
    // Instantiate 1 context for context phase
    for (auto i = 0; i < numCtxBatches; ++i)
        mRuntime->addContext(ctxContextId);

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
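With the new layout, all generation contexts are added first (two per generation micro batch, for flip-flopping), followed by one context per context micro batch. The sketch below spells out the index arithmetic that ordering implies. `getGenContextId` and `getCtxContextId` are used later in this commit but their bodies are not shown here, so treat the formulas as an illustration of the ordering above rather than the actual implementation.

```cpp
#include <cstdint>

// Hypothetical mirror of the MicroBatchConfig id helpers, for orientation only.
struct MicroBatchConfigSketch
{
    std::int32_t numCtxBatches;
    std::int32_t numGenBatches;

    // Generation contexts occupy ids [0, 2 * numGenBatches).
    std::int32_t getGenContextId(std::int32_t flipFlopId, std::int32_t genBatchId) const
    {
        return flipFlopId * numGenBatches + genBatchId;
    }

    // Context-phase contexts follow, grouped per generation micro batch.
    std::int32_t getCtxContextId(std::int32_t genBatchId, std::int32_t ctxBatchId) const
    {
        auto const ctxPerGen = numCtxBatches / numGenBatches;
        return 2 * numGenBatches + genBatchId * ctxPerGen + ctxBatchId;
    }
};
```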
@@ -184,11 +178,48 @@ void GptSession::createCustomAllReduceWorkspace(
{
    setPeerAccess(mWorldConfig, true);

    mIpcMemoryHandles.clear();
    const std::size_t bufferSize = static_cast<std::size_t>(maxBatchSize) * maxBeamWidth * maxSequenceLength
        * mModelConfig.getHiddenSize() * mWorldConfig.getTensorParallelism() * sizeof(float);
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(mWorldConfig, bufferSize));
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(mWorldConfig, IpcMemory::FLAGS_SIZE * sizeof(int32_t)));
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(mWorldConfig, IpcMemory::FLAGS_SIZE * sizeof(int32_t)));

    auto& manager = mRuntime->getBufferManager();
    for (const auto& buffer : mBuffers)
    mCommPtrs = manager.cpu(
        ITensor::makeShape({static_cast<SizeType>(mIpcMemoryHandles.size()) * mWorldConfig.getTensorParallelism()}),
        nvinfer1::DataType::kINT64);
    const auto commPtrsData = bufferCast<void*>(*mCommPtrs);

    for (size_t memIdx = 0; memIdx < mIpcMemoryHandles.size(); memIdx++)
    {
        buffer->createCustomAllReduceWorkspace(
            maxBatchSize, maxBeamWidth, maxSequenceLength, mModelConfig.getHiddenSize(), mWorldConfig, manager);
        const auto& memCommPtrs = mIpcMemoryHandles[memIdx]->getCommPtrsTensor();
        for (SizeType tpIdx = 0; tpIdx < mWorldConfig.getTensorParallelism(); tpIdx++)
        {
            commPtrsData[memIdx * mWorldConfig.getTensorParallelism() + tpIdx] = memCommPtrs[tpIdx];
        }
    }
}
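The custom all-reduce workspace ends up as a flat CPU tensor of `void*` pointers, one entry per (IPC handle, tensor-parallel rank) pair. A tiny sketch of the indexing, with invented sizes:

```cpp
#include <cstddef>

// Row-major index into the [numHandles, tpSize] pointer table built above.
std::size_t commPtrIndex(std::size_t memIdx, std::size_t tpIdx, std::size_t tpSize)
{
    return memIdx * tpSize + tpIdx;
}

// Example: 3 handles (workspace + two flag buffers) and TP = 4 gives a table of
// 12 pointers; handle 1 / rank 2 lives at index 1 * 4 + 2 == 6.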
GptSession::MicroBatchConfig::MicroBatchConfig(SizeType maxBatchSize, SizeType pipelineParallelism,
    std::optional<SizeType> genMicroBatchSize, std::optional<SizeType> ctxMicroBatchSize)
{
    if (genMicroBatchSize || ctxMicroBatchSize)
    {
        genBatchSize = genMicroBatchSize.value_or(maxBatchSize);
        TLLM_CHECK(genBatchSize <= maxBatchSize);
        ctxBatchSize = ctxMicroBatchSize.value_or(genBatchSize);
        TLLM_CHECK_WITH_INFO(genBatchSize % ctxBatchSize == 0,
            tc::fmtstr(
                "Generation batch size (%d) must be divisible by context batch size (%d)", genBatchSize, ctxBatchSize)
                .c_str());
        numGenBatches = tc::ceilDiv(maxBatchSize, genBatchSize);
        numCtxBatches = numGenBatches * (genBatchSize / ctxBatchSize);
    }
    else
    {
        numCtxBatches = numGenBatches = pipelineParallelism;
        ctxBatchSize = genBatchSize = tc::ceilDiv(maxBatchSize, numGenBatches);
    }
}
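A worked example of the micro-batch bookkeeping above, with invented sizes:

```cpp
#include <cassert>

int main()
{
    // Assumed session settings for the example only.
    int const maxBatchSize = 16, genBatchSize = 8, ctxBatchSize = 4;

    int const numGenBatches = (maxBatchSize + genBatchSize - 1) / genBatchSize; // ceilDiv -> 2
    int const numCtxBatches = numGenBatches * (genBatchSize / ctxBatchSize);    // 2 * 2 = 4

    assert(numGenBatches == 2); // two generation micro batches of 8 requests each
    assert(numCtxBatches == 4); // each split into two context micro batches of 4
    return 0;
}
```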
@@ -202,12 +233,12 @@ void GptSession::setup(Config const& sessionConfig)
    auto const maxBeamWidth = sessionConfig.maxBeamWidth;
    auto const maxSequenceLength = sessionConfig.maxSequenceLength;

    if (sessionConfig.numMicroBatches)
        mNumMicroBatches = sessionConfig.numMicroBatches.value();
    createContexts(mNumMicroBatches, sessionConfig.cudaGraphMode);
    createBuffers(mNumMicroBatches);
    mMicroBatchConfig = MicroBatchConfig(maxBatchSize, mWorldConfig.getPipelineParallelism(),
        sessionConfig.genMicroBatchSize, sessionConfig.ctxMicroBatchSize);

    createContexts(mMicroBatchConfig.numCtxBatches, mMicroBatchConfig.numGenBatches, sessionConfig.cudaGraphMode);
    createBuffers(mMicroBatchConfig.numGenBatches);

    auto const microBatchSize = tc::ceilDiv(maxBatchSize, mNumMicroBatches);
    // Store this param related to decoder buffer size and kv cache manager to check against
    // the input shape with the params given in generate().
    // gptDecoderBatch does not resize buffers, but allows smaller batchSize and beamWidth.
@@ -222,28 +253,29 @@ void GptSession::setup(Config const& sessionConfig)
    if (mWorldConfig.isLastPipelineParallelRank())
    {
        auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
        createDecoders(microBatchSize, maxBeamWidth, maxSequenceLength, logitsType, sessionConfig.decoderPerRequest,
            mNumMicroBatches);
        createDecoders(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength, logitsType,
            sessionConfig.decoderPerRequest, mMicroBatchConfig.numGenBatches);
    }

    if (mWorldConfig.isPipelineParallel() || mNumMicroBatches > 1)
    if (mWorldConfig.isPipelineParallel() || mMicroBatchConfig.numGenBatches > 1)
    {
        mReceivedEvents.clear();
        for (SizeType i = 0; i < mNumMicroBatches; ++i)
        for (SizeType i = 0; i < mMicroBatchConfig.numGenBatches; ++i)
            mReceivedEvents.emplace_back();
    }

    if (mWorldConfig.isTensorParallel() && mModelConfig.useCustomAllReduce())
    {
        createCustomAllReduceWorkspace(microBatchSize, maxBeamWidth, maxSequenceLength);
        createCustomAllReduceWorkspace(mMicroBatchConfig.genBatchSize, maxBeamWidth, maxSequenceLength);
    }

    // we don't know maxInputLength and maxNewTokens yet and ignore those for pre-allocation
    auto const generationConfig
        = RuntimeBuffers::GenerationConfig{microBatchSize, maxBeamWidth, 0, 0, maxSequenceLength};

    for (auto& buffers : mBuffers)
        buffers->reshape(generationConfig, mModelConfig, mWorldConfig);
    {
        // we don't know maxInputLength yet and ignore it for pre-allocation
        buffers->generationConfig
            = RuntimeBuffers::GenerationConfig{mMicroBatchConfig.genBatchSize, maxBeamWidth, 0, maxSequenceLength};
        buffers->reshape(mModelConfig, mWorldConfig);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

@@ -263,8 +295,8 @@ void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId,
    }
}

ITensor::SharedPtr GptSession::initNewTokens(
    GenerationInput const& inputs, SamplingConfig const& samplingConfig, SizeType microBatchId)
ITensor::SharedPtr GptSession::initDecoder(ITensor& outputIds, GenerationInput const& inputs,
    SamplingConfig const& samplingConfig, SizeType microBatchId) const
{
    if (mWorldConfig.isLastPipelineParallelRank())
    {
@@ -274,9 +306,29 @@ ITensor::SharedPtr GptSession::initNewTokens(
    }
    else if (mWorldConfig.isFirstPipelineParallelRank())
    {
        auto& manager = mRuntime->getBufferManager();
        auto const& stream = mRuntime->getStreamPtr();

        auto const inputLengths = inputs.lengths;
        auto const batchSize = static_cast<SizeType>(inputLengths->getSize());

        auto const inputLengthsHost = manager.copyFrom(*inputLengths, MemoryType::kCPU);
        auto const* inputLengthsData = bufferCast<SizeType>(*inputLengthsHost);
        SizeType const maxInputLength = *std::max_element(inputLengthsData, inputLengthsData + inputLengths->getSize());

        ITensor::SharedPtr inputOffsets = manager.emptyTensor(MemoryType::kGPU, TRTDataType<SizeType>::value);
        if (inputs.packed)
        {
            inputOffsets->reshape(ITensor::makeShape({batchSize + 1}));
            manager.setZero(*inputOffsets);
            kernels::invokeInclusiveSum(*ITensor::slice(inputOffsets, 1), *inputLengths, manager, *stream);
        }

        kernels::initOutputIds(outputIds, *inputs.ids, *inputLengths, *inputOffsets, inputs.padId, inputs.endId,
            maxInputLength, inputs.packed, *stream);

        auto const beamWidth = samplingConfig.beamWidth;
        auto const batchSize = static_cast<SizeType>(inputs.lengths->getSize());
        return mRuntime->getBufferManager().gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
        return manager.gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
    }
    else
    {
@@ -286,32 +338,34 @@ ITensor::SharedPtr GptSession::initNewTokens(

namespace
{
std::vector<GenerationInput> splitInputs(
    GenerationInput const& inputs, SizeType numMicroBatches, BufferManager& manager)
std::tuple<std::vector<ITensor::SharedPtr>, std::vector<ITensor::SharedPtr>, std::vector<SizeType>> splitInputIds(
    GenerationInput const& inputs, SizeType microBatchSize, BufferManager& manager)
{
    std::vector<GenerationInput> inputBatches;
    auto const numRequests = inputs.lengths->getShape().d[0];
    auto const microBatchSize = tc::ceilDiv(numRequests, numMicroBatches);

    std::vector<ITensor::SharedPtr> inputIds;
    std::vector<ITensor::SharedPtr> inputLengths;
    std::vector<SizeType> microBatchOffsets(1, 0);
    if (inputs.packed)
    {
        auto contextLengthsHost = manager.copyFrom(*inputs.lengths, MemoryType::kCPU);
        auto const contextLengthsHost = manager.copyFrom(*inputs.lengths, MemoryType::kCPU);
        ITensor::SharedPtr inputIdsView = ITensor::view(inputs.ids);
        inputIdsView->squeeze(0);
        auto contextLengthsRange = BufferRange<SizeType>(*contextLengthsHost);
        auto const contextLengthsRange = BufferRange<SizeType>(*contextLengthsHost);

        auto tokensBegin = 0;
        for (auto offset = 0; offset < numRequests; offset += microBatchSize)
        {
            auto batchSize = std::min(microBatchSize, numRequests - offset);
            auto numTokens = std::accumulate(
            auto const batchSize = std::min(microBatchSize, numRequests - offset);
            auto const numTokens = std::accumulate(
                contextLengthsRange.begin() + offset, contextLengthsRange.begin() + offset + batchSize, 0);

            ITensor::SharedPtr batchInputs = ITensor::slice(inputIdsView, tokensBegin, numTokens);
            batchInputs->reshape(ITensor::makeShape({1, numTokens}));

            inputBatches.emplace_back(inputs.endId, inputs.padId, batchInputs,
                ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
            inputIds.emplace_back(std::move(batchInputs));
            inputLengths.emplace_back(ITensor::slice(inputs.lengths, offset, batchSize));
            microBatchOffsets.emplace_back(offset + batchSize);

            tokensBegin += numTokens;
        }
@@ -320,24 +374,66 @@ std::vector<GenerationInput> splitInputs(
    {
        for (auto offset = 0; offset < numRequests; offset += microBatchSize)
        {
            auto batchSize = std::min(microBatchSize, numRequests - offset);
            inputBatches.emplace_back(inputs.endId, inputs.padId, ITensor::slice(inputs.ids, offset, batchSize),
                ITensor::slice(inputs.lengths, offset, batchSize), inputs.packed);
            auto const batchSize = std::min(microBatchSize, numRequests - offset);

            inputIds.emplace_back(ITensor::slice(inputs.ids, offset, batchSize));
            inputLengths.emplace_back(ITensor::slice(inputs.lengths, offset, batchSize));
            microBatchOffsets.emplace_back(offset + batchSize);
        }
    }

    for (auto& batch : inputBatches)
    return {inputIds, inputLengths, microBatchOffsets};
}
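For packed inputs, splitInputIds carves micro batches out of a single [1, totalTokens] tensor by accumulating the context lengths each micro batch covers. A standalone sketch of that arithmetic, with plain std containers standing in for ITensor and all numbers invented:

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

int main()
{
    std::vector<int> contextLengths{3, 5, 2, 4}; // four requests, tokens packed back to back
    int const microBatchSize = 2;

    std::vector<std::pair<int, int>> tokenRanges; // {firstToken, numTokens} per micro batch
    int tokensBegin = 0;
    for (std::size_t offset = 0; offset < contextLengths.size(); offset += microBatchSize)
    {
        auto const batchSize = std::min<std::size_t>(microBatchSize, contextLengths.size() - offset);
        auto const numTokens = std::accumulate(
            contextLengths.begin() + offset, contextLengths.begin() + offset + batchSize, 0);
        tokenRanges.emplace_back(tokensBegin, numTokens); // -> {0, 8} then {8, 6}
        tokensBegin += numTokens;
    }
    return 0;
}
```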
std::vector<GenerationInput> splitInputs(GenerationInput const& inputs, SizeType microBatchSize, BufferManager& manager)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto [inputIds, inputLengths, microBatchOffsets] = splitInputIds(inputs, microBatchSize, manager);

    std::vector<GenerationInput> inputBatches;
    for (std::size_t batchId = 0; batchId < inputIds.size(); ++batchId)
    {
        inputBatches.emplace_back(
            inputs.endId, inputs.padId, std::move(inputIds[batchId]), std::move(inputLengths[batchId]), inputs.packed);
    }

    for (std::size_t batchId = 0; batchId < inputBatches.size(); ++batchId)
    {
        auto& batch = inputBatches[batchId];
        auto const offset = microBatchOffsets[batchId];
        auto const batchSize = microBatchOffsets[batchId + 1] - offset;

        if (inputs.embeddingBiasOpt)
            batch.embeddingBiasOpt = inputs.embeddingBiasOpt;
        if (inputs.badWordsList)
            batch.badWordsList = inputs.badWordsList;
        {
            auto const& shape = inputs.badWordsList->getShape();
            if (shape.nbDims == 2)
            {
                batch.badWordsList = inputs.badWordsList;
            }
            else
            {
                assert(shape.nbDims == 3);
                batch.badWordsList = ITensor::slice(inputs.badWordsList, offset, batchSize);
            }
        }
        if (inputs.stopWordsList)
            batch.stopWordsList = inputs.stopWordsList;
        {
            batch.stopWordsList = ITensor::slice(inputs.stopWordsList, offset, batchSize);
        }
        if (inputs.maxNewTokens)
            batch.maxNewTokens = inputs.maxNewTokens;

        if (inputs.promptTuningParams.embeddingTable)
            batch.promptTuningParams.embeddingTable = inputs.promptTuningParams.embeddingTable;
        if (inputs.promptTuningParams.tasks)
            batch.promptTuningParams.tasks = ITensor::slice(inputs.promptTuningParams.tasks, offset, batchSize);
        if (inputs.promptTuningParams.vocabSize)
            batch.promptTuningParams.vocabSize = inputs.promptTuningParams.vocabSize;
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return inputBatches;
}

@@ -381,40 +477,33 @@ void GptSession::generate(
        outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
    }

    auto const numMicroBatches = std::min(batchSize, mNumMicroBatches);
    if (numMicroBatches == 1)
    if (batchSize <= mMicroBatchConfig.genBatchSize)
    {
        std::vector<GenerationInput> microBatches{inputs};
        generateBatched(outputs, microBatches, samplingConfig);
    }
    else
    {
        auto const microBatches = splitInputs(inputs, numMicroBatches, manager);
        auto const microBatches = splitInputs(inputs, mMicroBatchConfig.genBatchSize, manager);
        generateBatched(outputs, microBatches, samplingConfig);
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

std::function<void(SizeType microBatchId, SizeType step, bool finished)> GptSession::createOnTokenGeneratedCallback(
    GenerationOutput& outputs, SizeType numMicroBatches)
std::function<void(SizeType step, bool finished)> GptSession::createOnTokenGeneratedCallback(GenerationOutput& outputs)
{
    if (outputs.onTokenGenerated && mWorldConfig.isFirstPipelineParallelRank())
    {
        ITensor::SharedPtr outputIds{mWorldConfig.isPipelineParallel() || mNumMicroBatches > 1
        ITensor::SharedPtr outputIds{mWorldConfig.isPipelineParallel() || mMicroBatchConfig.numGenBatches > 1
                ? outputs.ids
                : mDecoders.front()->getOutputIds()};
        auto const lastMicroBatchId = numMicroBatches - 1;
        return [onTokenGenerated = outputs.onTokenGenerated, outputIds = std::move(outputIds), lastMicroBatchId](
                   SizeType microBatchId, SizeType step, bool finished)
        {
            if (microBatchId == lastMicroBatchId)
                onTokenGenerated(outputIds, step, finished);
        };
        return [onTokenGenerated = outputs.onTokenGenerated, outputIds = std::move(outputIds)](
                   SizeType step, bool finished) { onTokenGenerated(outputIds, step, finished); };
    }
    else
    {
        return [](SizeType microBatchId, SizeType step, bool finished) {};
        return [](SizeType step, bool finished) {};
    }
}
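Client code can observe intermediate results through GenerationOutput::onTokenGenerated; the wrapper built above forwards only the step and the finished flag, while the ids tensor is captured once. A hedged usage sketch, with the callback signature inferred from how it is invoked in this hunk:

```cpp
#include "tensorrt_llm/runtime/generationOutput.h"

using tensorrt_llm::runtime::GenerationOutput;
using tensorrt_llm::runtime::ITensor;
using tensorrt_llm::runtime::SizeType;

void attachCallback(GenerationOutput& outputs)
{
    // Assumed parameter order (ids, step, finished), matching the call site above.
    outputs.onTokenGenerated
        = [](ITensor::SharedPtr const& outputIds, SizeType step, bool finished)
    {
        // outputIds already holds the tokens produced up to this step; a real
        // client could copy or stream them from here.
        (void) outputIds;
        (void) step;
        (void) finished;
    };
}
```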
@@ -426,52 +515,50 @@ void GptSession::generateBatched(
    auto& manager = mRuntime->getBufferManager();
    auto const numMicroBatches = static_cast<SizeType>(microBatches.size());
    TLLM_CHECK(numMicroBatches > 0);
    TLLM_CHECK(numMicroBatches <= mNumMicroBatches);
    TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches);
    SizeType const beamWidth{samplingConfig.beamWidth};

    // Initialize and reshape buffers
    std::vector<RuntimeBuffers::GenerationConfig> generationConfigs;
    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto const& microBatchInputs = microBatches.at(microBatchId);
        auto& buffers = *mBuffers.at(microBatchId);
        TLLM_CHECK_WITH_INFO(buffers.allocated, "Buffers not allocated, please call setup first!");
        buffers.initContextLengths(microBatchInputs.lengths, manager);
        generationConfigs.emplace_back(
            RuntimeBuffers::GenerationConfig::fromInput(*microBatchInputs.ids, *buffers.contextLengthsHost,
                microBatchInputs.packed, beamWidth, mDecoderMaxSequenceLength, microBatchInputs.maxNewTokens));
        buffers.reshape(generationConfigs.back(), mModelConfig, mWorldConfig);
        buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
            mDecoderMaxSequenceLength, manager);
        buffers.reshape(mModelConfig, mWorldConfig);
        buffers.reset(manager);
    }

    auto minMaxNewTokens = std::numeric_limits<SizeType>::max();
    std::vector<SizeType> microBatchOffsets(1, 0);
    microBatchOffsets.reserve(numMicroBatches + 1);
    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto const& generationConfig = generationConfigs.at(microBatchId);
        minMaxNewTokens = std::min(minMaxNewTokens, generationConfig.maxNewTokens);
        auto const& generationConfig = mBuffers.at(microBatchId)->generationConfig;
        microBatchOffsets.emplace_back(microBatchOffsets.back() + generationConfig.batchSize);
    }

    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto& buffers = *mBuffers.at(microBatchId);
        auto const& generationConfig = generationConfigs.at(microBatchId);
        auto const batchOffset = microBatchOffsets.at(microBatchId);
        kvCacheAddSequences(beamWidth, microBatchId, batchOffset);
        auto const& microBatchInputs = microBatches.at(microBatchId);
        buffers.newTokens = initNewTokens(microBatchInputs, samplingConfig, microBatchId);
        auto const microBatchSize = generationConfig.batchSize;
        auto const microBatchSize = buffers.generationConfig.batchSize;
        buffers.outputIds = ITensor::slice(outputs.ids, batchOffset, microBatchSize);
        buffers.outputLengths = ITensor::slice(outputs.lengths, batchOffset, microBatchSize);
        buffers.newTokens = initDecoder(*buffers.outputIds, microBatchInputs, samplingConfig, microBatchId);
        if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
        {
            buffers.logits = ITensor::slice(outputs.contextLogits, batchOffset, microBatchSize);
        }
        if (mModelConfig.usePromptTuning())
        {
            buffers.promptTuningParams = microBatchInputs.promptTuningParams;
        }
    }

    // Prepare the onTokenGenerated callback
    auto const onTokenGenerated = createOnTokenGeneratedCallback(outputs, numMicroBatches);
    auto const onTokenGenerated = createOnTokenGeneratedCallback(outputs);

    if (useCudaGraphs())
    {
@@ -483,101 +570,25 @@ void GptSession::generateBatched(

    auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;

    std::vector<RuntimeBuffers::TensorMap> inputBuffers(numMicroBatches * 2);
    std::vector<RuntimeBuffers::TensorMap> outputBuffers(numMicroBatches * 2);
    executeContextStep(microBatches, microBatchOffsets, kvCacheManager);

    std::vector<bool> microBatchesFinished(numMicroBatches, false);
    auto notFinished = [&microBatchesFinished]()
    { return std::any_of(microBatchesFinished.begin(), microBatchesFinished.end(), [](bool x) { return !x; }); };

    for (SizeType step = 0; step < minMaxNewTokens && notFinished(); ++step)
    SizeType numBatchesFinished{0};
    SizeType step{0};
    while (numBatchesFinished < numMicroBatches)
    {
        auto const flipFlopId = step % 2;
        for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
        {
            if (microBatchesFinished.at(microBatchId))
                continue;
        ++step;
        numBatchesFinished
            += executeGenerationStep(step, microBatches, microBatchOffsets, kvCacheManager, microBatchesFinished);

            auto& buffers = *mBuffers.at(microBatchId);
            auto& generationConfig = generationConfigs.at(microBatchId);

            auto const contextId = flipFlopId * numMicroBatches + microBatchId;
            auto& inputBuffer = inputBuffers[contextId];
            auto& outputBuffer = outputBuffers[contextId];

            if (step == 0)
            {
                SizeType const contextIdForContextPhase
                    = (mRuntime->getNbProfiles() == 2 ? 2 * mNumMicroBatches : 0) + microBatchId;

                auto const& microBatchInputs = microBatches.at(microBatchId);
                buffers.prepareContextStep(microBatchInputs.ids, microBatchInputs.padId, manager, kvCacheManager,
                    microBatchOffsets.at(microBatchId), generationConfig, mModelConfig, mWorldConfig);
                buffers.getRuntimeBuffers(
                    inputBuffer, outputBuffer, step, microBatchInputs.ids, mModelConfig, mWorldConfig);
                mRuntime->setInputTensors(contextIdForContextPhase, inputBuffer);
                mRuntime->setOutputTensors(contextIdForContextPhase, outputBuffer);

                TLLM_CHECK_WITH_INFO(
                    mRuntime->executeContext(contextIdForContextPhase), "Executing TRT engine in context step failed!");
                sync_check_cuda_error();

                buffers.postContextStep(manager, generationConfig, mModelConfig, mWorldConfig);
                sync_check_cuda_error();
            }
            else
            {
                auto nextInputIds = buffers.prepareNextStep(step - 1, manager, kvCacheManager,
                    microBatchOffsets.at(microBatchId), generationConfig, mModelConfig, mWorldConfig);
                buffers.getRuntimeBuffers(inputBuffer, outputBuffer, step, nextInputIds, mModelConfig, mWorldConfig);
                mRuntime->setInputTensors(contextId, inputBuffer);
                mRuntime->setOutputTensors(contextId, outputBuffer);

                if (useCudaGraphs())
                {
                    mCudaGraphInstances.at(contextId).prepareNextGraph(*mRuntime, contextId);
                }

                // check decoder result of previous iteration
                auto const microBatchSize = generationConfig.batchSize;
                auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
                onTokenGenerated(microBatchId, step - 1, shouldStop);

                if (shouldStop)
                {
                    mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, "GPT decoding finished early");
                    microBatchesFinished.at(microBatchId) = true;
                    continue;
                }

                if (useCudaGraphs())
                {
                    auto& cudaGraphInstance = mCudaGraphInstances.at(contextId);
                    TLLM_CHECK(cudaGraphInstance.hasInstance());
                    cudaGraphInstance.launch(mRuntime->getStream());
                }
                else
                {
                    TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId),
                        tc::fmtstr("Executing TRT engine in step %d failed!", step));
                }
                sync_check_cuda_error();
            }

            std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

            auto const maxInputLength = generationConfigs.at(microBatchId).maxInputLength;
            auto const decoderStep = maxInputLength + step;
            decoderStepAsync(decoderStep, microBatchId);
        }
        onTokenGenerated(step - 1, numBatchesFinished == numMicroBatches);
    }

    // Collect the results for the last step
    for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
    {
        auto const& generationConfig = generationConfigs.at(microBatchId);
        auto const& generationConfig = mBuffers.at(microBatchId)->generationConfig;
        auto const microBatchSize = generationConfig.batchSize;
        auto const shouldStop = shouldStopSync(microBatchSize, beamWidth, microBatchId);
        onTokenGenerated(microBatchId, minMaxNewTokens - 1, shouldStop);

        auto const firstBatchIdx = microBatchOffsets.at(microBatchId);
        if (mModelConfig.usePagedKvCache())
@@ -594,10 +605,129 @@ void GptSession::generateBatched(

        else if (!mWorldConfig.isPipelineParallel())
            manager.copy(*mDecoders.at(microBatchId)->getOutputIds(), *mBuffers.at(microBatchId)->outputIds);
    }

    manager.getStream().synchronize();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptSession::executeContextStep(std::vector<GenerationInput> const& generationBatches,
    std::vector<SizeType> const& generationBatchOffsets, KvCacheManager const* kvCacheManager)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& manager = mRuntime->getBufferManager();

    auto const numGenerationBatches = static_cast<SizeType>(generationBatches.size());
    auto constexpr step = 0;
    for (auto generationBatchId = 0; generationBatchId < numGenerationBatches; ++generationBatchId)
    {
        auto const& generationBatchInputs = generationBatches.at(generationBatchId);
        auto& generationBuffers = *mBuffers.at(generationBatchId);

        auto const contextBatchSize = mMicroBatchConfig.ctxBatchSize;
        auto [inputIds, inputLengths, contextBatchOffsets]
            = splitInputIds(generationBatchInputs, contextBatchSize, manager);
        auto contextBuffers = generationBuffers.split(contextBatchSize, mModelConfig, mWorldConfig);
        TLLM_CHECK(inputIds.size() == contextBuffers.size());
        auto const numContextBatches = static_cast<SizeType>(contextBuffers.size());

        for (auto contextBatchId = 0; contextBatchId < numContextBatches; ++contextBatchId)
        {
            auto batchOffset = generationBatchOffsets.at(generationBatchId) + contextBatchOffsets.at(contextBatchId);
            auto& buffers = contextBuffers.at(contextBatchId);
            auto& inputBuffer = buffers.inputBuffers[0];
            auto& outputBuffer = buffers.outputBuffers[0];

            auto const contextId = mMicroBatchConfig.getCtxContextId(generationBatchId, contextBatchId);

            buffers.prepareContextStep(inputIds.at(contextBatchId), generationBatchInputs.padId, manager,
                kvCacheManager, batchOffset, mModelConfig, mWorldConfig);
            buffers.getRuntimeBuffers(
                inputBuffer, outputBuffer, step, inputIds.at(contextBatchId), mCommPtrs, mModelConfig, mWorldConfig);
            mRuntime->setInputTensors(contextId, inputBuffer);
            mRuntime->setOutputTensors(contextId, outputBuffer);

            TLLM_CHECK_WITH_INFO(mRuntime->executeContext(contextId), "Executing TRT engine in context step failed!");
            sync_check_cuda_error();
        }

        generationBuffers.postContextStep(contextBuffers, manager, mModelConfig, mWorldConfig);
        sync_check_cuda_error();

        std::swap(generationBuffers.cacheIndirectionDecoderInput, generationBuffers.cacheIndirectionDecoderOutput);

        auto const decoderStep = generationBuffers.generationConfig.maxInputLength + step;
        decoderStepAsync(decoderStep, generationBatchId);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
SizeType GptSession::executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatches,
    std::vector<SizeType> const& microBatchOffsets, KvCacheManager* kvCacheManager,
    std::vector<bool>& microBatchesFinished)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& manager = mRuntime->getBufferManager();

    auto const numMicroBatches = static_cast<SizeType>(microBatches.size());
    SizeType numBatchesFinished{0};

    auto const flipFlopId = step % 2;
    for (auto generationBatchId = 0; generationBatchId < numMicroBatches; ++generationBatchId)
    {
        if (microBatchesFinished.at(generationBatchId))
            continue;

        auto& buffers = *mBuffers.at(generationBatchId);
        auto const& generationConfig = buffers.generationConfig;

        auto const contextId = mMicroBatchConfig.getGenContextId(flipFlopId, generationBatchId);
        auto& inputBuffer = buffers.inputBuffers[flipFlopId];
        auto& outputBuffer = buffers.outputBuffers[flipFlopId];

        auto nextInputIds = buffers.prepareNextStep(
            step - 1, manager, kvCacheManager, microBatchOffsets.at(generationBatchId), mModelConfig, mWorldConfig);
        buffers.getRuntimeBuffers(inputBuffer, outputBuffer, step, nextInputIds, mCommPtrs, mModelConfig, mWorldConfig);
        mRuntime->setInputTensors(contextId, inputBuffer);
        mRuntime->setOutputTensors(contextId, outputBuffer);

        if (useCudaGraphs())
        {
            mCudaGraphInstances.at(contextId).prepareNextGraph(*mRuntime, contextId);
        }

        // check decoder result of previous iteration
        if (shouldStopSync(generationConfig.batchSize, generationConfig.beamWidth, generationBatchId))
        {
            mLogger->log(nvinfer1::ILogger::Severity::kVERBOSE,
                tc::fmtstr("GPT decoding finished for step %d and microBatchId %d", step, generationBatchId).c_str());
            microBatchesFinished.at(generationBatchId) = true;
            numBatchesFinished += 1;
            continue;
        }

        if (useCudaGraphs())
        {
            auto& cudaGraphInstance = mCudaGraphInstances.at(contextId);
            TLLM_CHECK(cudaGraphInstance.hasInstance());
            cudaGraphInstance.launch(mRuntime->getStream());
        }
        else
        {
            TLLM_CHECK_WITH_INFO(
                mRuntime->executeContext(contextId), tc::fmtstr("Executing TRT engine in step %d failed!", step));
        }
        sync_check_cuda_error();

        std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);

        auto const decoderStep = generationConfig.maxInputLength + step;
        decoderStepAsync(decoderStep, generationBatchId);
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return numBatchesFinished;
}
void GptSession::decoderStepAsync(SizeType decoderStep, SizeType microBatchId)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -662,7 +792,7 @@ void GptSession::decoderStepAsync(SizeType decoderStep, SizeType microBatchId)
        mCommStream->record(mReceivedEvents.at(microBatchId).get());
    }

    if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
    if (!mWorldConfig.isPipelineParallel() && mMicroBatchConfig.numGenBatches > 1)
    {
        updateOutputIds(outputIds, newTokens, decoderStep, stream);
        stream.record(mReceivedEvents.at(microBatchId).get());
@@ -684,7 +814,7 @@ bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType
    decoder.forwardSync();
    nbFinished = *bufferCast<SizeType>(*decoder.getNbFinished());

    if (!mWorldConfig.isPipelineParallel() && mNumMicroBatches > 1)
    if (!mWorldConfig.isPipelineParallel() && mMicroBatchConfig.numGenBatches > 1)
    {
        // ensure outputIds have been updated
        mReceivedEvents.at(microBatchId).synchronize();

@@ -57,6 +57,12 @@ std::string MemoryCounters::bytesToString(DiffType bytes, int precision)
    return doubleBytesToString(static_cast<double>(bytes), precision);
}

std::string MemoryCounters::toString() const
{
    return tensorrt_llm::common::fmtstr("[MemUsage] GPU %s, CPU %s, Pinned %s", bytesToString(this->getGpu()).c_str(),
        bytesToString(this->getCpu()).c_str(), bytesToString(this->getPinned()).c_str());
}

void MemoryCounters::allocate(MemoryType memoryType, MemoryCounters::SizeType size)
{
    switch (memoryType)

92 cpp/tensorrt_llm/runtime/promptTuningParams.cpp Normal file
@@ -0,0 +1,92 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/promptTuningParams.h"

namespace tensorrt_llm::runtime
{

void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize,
    const SizeType numContextRequests, const std::vector<SizeType>& reqBeamWidths,
    const std::vector<SizeType>& reqPromptLengths, BufferManager& manager, bool packedInput)
{
    auto const& tasksHostShape = tasksHost->getShape();
    TLLM_CHECK_WITH_INFO(tasksHostShape.nbDims == 1, "tasksHost expected to have dimension [batchSize]");
    TLLM_CHECK_WITH_INFO(tasksHostShape.d[0] == batchSize, "tasksHost expected to have dimension [batchSize]");

    auto const tasksHostPtr = bufferCast<SizeType const>(*tasksHost);

    bool validInput = packedInput || numContextRequests == batchSize || numContextRequests == 0;
    TLLM_CHECK_WITH_INFO(validInput,
        "fillTasksTensor function with packed inputs must be called with only context requests or only generation "
        "requests.");

    bool validShapes = (static_cast<SizeType>(reqBeamWidths.size()) == batchSize
        && static_cast<SizeType>(reqPromptLengths.size()) == numContextRequests
        && static_cast<SizeType>(promptTuningEnabled.size()) == batchSize);
    TLLM_CHECK_WITH_INFO(validShapes,
        "Invalid inputs to fillTasksTensor function. reqBeamWidths and reqPtuningEnabled size must be batchSize and "
        "promptLengths size must be numContextRequests");

    SizeType totalInputSize = 0;
    std::vector<SizeType> promptTasksHost;
    for (SizeType bid = 0; bid < batchSize; bid++)
    {
        SizeType taskId = promptTuningEnabled[bid] ? tasksHostPtr[bid] : 0;
        if (packedInput)
        {
            if (bid < numContextRequests)
            {
                totalInputSize += reqPromptLengths[bid];
                promptTasksHost.insert(promptTasksHost.end(), reqPromptLengths[bid], taskId);
            }
            else
            {
                for (SizeType beam = 0; beam < reqBeamWidths[bid]; ++beam)
                {
                    promptTasksHost.insert(promptTasksHost.end(), 1, taskId);
                    totalInputSize++;
                }
            }
        }
        else
        {
            if (bid < numContextRequests)
            {
                promptTasksHost.push_back(taskId);
                ++totalInputSize;
            }
            else
            {
                promptTasksHost.insert(promptTasksHost.end(), reqBeamWidths[bid], taskId);
                totalInputSize += reqBeamWidths[bid];
            }
        }
    }

    if (packedInput)
    {
        tasks = manager.copyFrom(
            promptTasksHost, runtime::ITensor::makeShape({1, totalInputSize}), runtime::MemoryType::kGPU);
    }
    else
    {
        tasks = manager.copyFrom(
            promptTasksHost, runtime::ITensor::makeShape({totalInputSize, 1}), runtime::MemoryType::kGPU);
    }
}

} // namespace tensorrt_llm::runtime
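To make the packed case above concrete, here is a standalone rerun of the fill loop on plain vectors; the request mix and task ids are invented for illustration.

```cpp
#include <cassert>
#include <vector>

int main()
{
    // Two context requests (prompt lengths 4 and 2) followed by one generating
    // request with beam width 2. Request 1 has prompt tuning disabled, so its
    // task id collapses to 0.
    std::vector<int> taskIds{7, 5, 9};
    std::vector<bool> ptuningEnabled{true, false, true};
    std::vector<int> promptLengths{4, 2}; // only for the context requests
    std::vector<int> beamWidths{1, 1, 2};
    int const numContextRequests = 2;

    std::vector<int> tasks;
    for (std::size_t bid = 0; bid < taskIds.size(); ++bid)
    {
        int const taskId = ptuningEnabled[bid] ? taskIds[bid] : 0;
        if (bid < static_cast<std::size_t>(numContextRequests))
            tasks.insert(tasks.end(), promptLengths[bid], taskId); // one id per prompt token
        else
            tasks.insert(tasks.end(), beamWidths[bid], taskId);    // one id per beam
    }

    assert((tasks == std::vector<int>{7, 7, 7, 7, 0, 0, 9, 9})); // uploaded as shape [1, 8]
    return 0;
}
```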
@@ -16,7 +16,6 @@

#include "tensorrt_llm/runtime/runtimeBuffers.h"

#include "ipcUtils.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
@@ -30,8 +29,7 @@ using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;

RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITensor const& inputIds,
    ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxSequenceLength,
    std::optional<SizeType> const& maxNewTokensOpt)
    ITensor const& inputLengthsHost, bool const inputPacked, SizeType const beamWidth, SizeType const maxSequenceLength)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const batchSize = static_cast<SizeType>(inputLengthsHost.getSize());
@@ -54,13 +52,12 @@ RuntimeBuffers::GenerationConfig RuntimeBuffers::GenerationConfig::fromInput(ITe
        maxInputLength = inputShape.d[1];
    }

    auto const maxNewTokens = maxNewTokensOpt.value_or(maxSequenceLength - maxInputLength);
    TLLM_CHECK_WITH_INFO(1 <= maxNewTokens && maxNewTokens <= maxSequenceLength - maxInputLength,
    TLLM_CHECK_WITH_INFO(maxInputLength < maxSequenceLength,
        "Max input length is equal to or larger than maxSequenceLength given in setup. No new tokens can be "
        "generated.");

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return GenerationConfig{batchSize, beamWidth, maxInputLength, maxNewTokens, maxSequenceLength, inputLengthSum};
    return GenerationConfig{batchSize, beamWidth, maxInputLength, maxSequenceLength, inputLengthSum};
}

void RuntimeBuffers::clear()
@@ -91,6 +88,16 @@ void RuntimeBuffers::clear()
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::clearTensorMaps()
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    for (auto& buffer : inputBuffers)
        buffer.clear();
    for (auto& buffer : outputBuffers)
        buffer.clear();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -171,41 +178,19 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::initContextLengths(TensorPtr const& inputLengths, BufferManager& manager)
void RuntimeBuffers::initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked,
    SizeType beamWidth, SizeType maxSequenceLength, BufferManager& manager)
{
    contextLengthsDevice = inputLengths;
    contextLengthsHost->reshape(inputLengths->getShape());
    manager.copy(*contextLengthsDevice, *contextLengthsHost);
    manager.getStream().synchronize(); // wait for context lengths to be copied to host

    generationConfig = RuntimeBuffers::GenerationConfig::fromInput(
        inputIds, *contextLengthsHost, inputPacked, beamWidth, maxSequenceLength);
}

void RuntimeBuffers::createCustomAllReduceWorkspace(SizeType maxBatchSize, SizeType maxBeamWidth,
    SizeType maxSequenceLength, SizeType hiddenSize, WorldConfig const& worldConfig, BufferManager& manager)
{
    mIpcMemoryHandles.clear();
    const std::size_t bufferSize = static_cast<std::size_t>(maxBatchSize) * maxBeamWidth * maxSequenceLength
        * hiddenSize * worldConfig.getTensorParallelism() * sizeof(float);
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(worldConfig, bufferSize));
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(worldConfig, IpcMemory::FLAGS_SIZE * sizeof(int32_t)));
    mIpcMemoryHandles.emplace_back(std::make_shared<IpcMemory>(worldConfig, IpcMemory::FLAGS_SIZE * sizeof(int32_t)));

    commPtrs = manager.cpu(
        ITensor::makeShape({static_cast<SizeType>(mIpcMemoryHandles.size()) * worldConfig.getTensorParallelism()}),
        nvinfer1::DataType::kINT64);
    const auto commPtrsData = bufferCast<void*>(*commPtrs);

    for (size_t memIdx = 0; memIdx < mIpcMemoryHandles.size(); memIdx++)
    {
        const auto& memCommPtrs = mIpcMemoryHandles[memIdx]->getCommPtrsTensor();
        for (SizeType tpIdx = 0; tpIdx < worldConfig.getTensorParallelism(); tpIdx++)
        {
            commPtrsData[memIdx * worldConfig.getTensorParallelism() + tpIdx] = memCommPtrs[tpIdx];
        }
    }
}

void RuntimeBuffers::reshape(
    GenerationConfig const& generationConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

@@ -222,8 +207,10 @@ void RuntimeBuffers::reshape(

    lastTokenIds->reshape(ITensor::makeShape({batchSize}));

    auto kvCacheShape
    auto kvCacheReserve
        = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxSeqLength, modelConfig.getSizePerHead()});
    auto kvCacheShape
        = ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
    if (modelConfig.usePagedKvCache())
    {
        auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
@@ -240,7 +227,7 @@ void RuntimeBuffers::reshape(
    }
    else
    {
        utils::reshapeBufferVector(presentKeysVals, kvCacheShape);
        utils::reshapeBufferVector(presentKeysVals, kvCacheReserve);
    }

    if (modelConfig.useGptAttentionPlugin())
@@ -250,7 +237,10 @@ void RuntimeBuffers::reshape(
    }
    else
    {
        utils::reshapeBufferVector(presentKeysValsAlt, kvCacheShape);
        utils::reshapeBufferVector(presentKeysValsAlt, kvCacheReserve);
        // present KV cache tensors will be reshaped by shape inference.
        // reshape to the required shape here to make context batch slicing work correctly.
        utils::reshapeBufferVector(presentKeysVals, kvCacheShape);
    }

    auto const cacheIndirShape = ITensor::makeShape({batchSize, beamWidth, maxSeqLength});
@@ -260,9 +250,9 @@ void RuntimeBuffers::reshape(
    if (worldConfig.isPipelineParallel())
    {
        // reserve max size
        auto const maxNumTokens = std::max(batchSize * beamWidth, batchSize * maxInputLength);
        auto const maxNumTokens = std::max(beamWidth, maxInputLength);
        auto const hiddenSize = modelConfig.getHiddenSize() * worldConfig.getTensorParallelism();
        auto const hiddenStatesShape = ITensor::makeShape({1, maxNumTokens, hiddenSize});
        auto const hiddenStatesShape = ITensor::makeShape({batchSize, maxNumTokens, hiddenSize});
        hiddenStates->reshape(hiddenStatesShape);
    }

@@ -270,8 +260,104 @@ void RuntimeBuffers::reshape(
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::gatherLastTokenLogits(BufferManager& manager, GenerationConfig const& generationConfig,
    GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
void RuntimeBuffers::reset(BufferManager& manager)
{
    clearTensorMaps();
    manager.setZero(*cacheIndirectionDecoderInput);
    manager.setZero(*cacheIndirectionDecoderOutput);
}

std::vector<RuntimeBuffers> RuntimeBuffers::split(
    SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);

    std::vector<RuntimeBuffers> bufferSlices;
    auto const generationBatchSize = generationConfig.batchSize;
    bufferSlices.reserve(tc::ceilDiv(generationBatchSize, contextBatchSize));
    if (contextBatchSize >= generationBatchSize)
    {
        bufferSlices.emplace_back(*this);
    }
    else
    {
        for (auto offset = 0; offset < generationBatchSize; offset += contextBatchSize)
        {
            auto const batchSize = std::min(contextBatchSize, generationBatchSize - offset);
            auto& buffers = bufferSlices.emplace_back();
            buffers.generationConfig = generationConfig;
            buffers.generationConfig.batchSize = batchSize;

            buffers.contextLengthsHost = ITensor::slice(contextLengthsHost, offset, batchSize);
            buffers.contextLengthsDevice = ITensor::slice(contextLengthsDevice, offset, batchSize);

            if (worldConfig.isLastPipelineParallelRank() && !modelConfig.computeContextLogits())
            {
                buffers.logits = ITensor::slice(logits, offset, batchSize);
            }

            buffers.lastTokenIds = ITensor::slice(lastTokenIds, offset, batchSize);

            if (modelConfig.usePagedKvCache())
            {
                auto const& realCacheBlockPointersShape = kvCacheBlockPointersHost->getShape();
                auto const localNbLayers = realCacheBlockPointersShape.d[0];
                auto const maxBlocksPerSeq = realCacheBlockPointersShape.d[3];

                // enable slicing by moving generationBatchSize to first dim
                auto const fakeCacheBlockPointersShape
                    = ITensor::makeShape({generationBatchSize, localNbLayers, 2, maxBlocksPerSeq});
                TensorPtr kvCacheBlockPointersHostView{
                    ITensor::view(kvCacheBlockPointersHost, fakeCacheBlockPointersShape)};
                TensorPtr kvCacheBlockPointersDeviceView{
                    ITensor::view(kvCacheBlockPointersDevice, fakeCacheBlockPointersShape)};

                // slice and reshape to correct shape
                auto const cacheBlockPointersShape = ITensor::makeShape({localNbLayers, batchSize, 2, maxBlocksPerSeq});
                buffers.kvCacheBlockPointersHost = ITensor::slice(kvCacheBlockPointersHostView, offset, batchSize);
                buffers.kvCacheBlockPointersHost->reshape(cacheBlockPointersShape);
                buffers.kvCacheBlockPointersDevice = ITensor::slice(kvCacheBlockPointersDeviceView, offset, batchSize);
                buffers.kvCacheBlockPointersDevice->reshape(cacheBlockPointersShape);
            }
            else
            {
                buffers.presentKeysVals = utils::sliceBufferVector(presentKeysVals, offset, batchSize);
            }

            if (modelConfig.useGptAttentionPlugin())
            {
                buffers.pastKeyValueLengths = ITensor::slice(pastKeyValueLengths, offset, batchSize);
                buffers.requestTypes = ITensor::slice(requestTypes, offset, batchSize);
            }
            else
            {
                buffers.presentKeysValsAlt = utils::sliceBufferVector(presentKeysValsAlt, offset, batchSize);
            }

            if (worldConfig.isPipelineParallel())
            {
                buffers.hiddenStates = ITensor::slice(hiddenStates, offset, batchSize);
            }

            buffers.cacheIndirectionDecoderOutput = ITensor::slice(cacheIndirectionDecoderOutput, offset, batchSize);

            if (modelConfig.usePromptTuning())
            {
                auto const& ptuningEnabled = promptTuningParams.promptTuningEnabled;
                buffers.promptTuningParams.promptTuningEnabled
                    = std::vector<bool>(ptuningEnabled.begin() + offset, ptuningEnabled.begin() + offset + batchSize);

                buffers.promptTuningParams.tasks = ITensor::slice(promptTuningParams.tasks, offset, batchSize);
            }
        }
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return bufferSlices;
}
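The paged KV cache branch above can only slice along a tensor's leading dimension, so it temporarily views the block-pointer tensors with the generation batch in front, slices, and reshapes back. A shape-only sketch of that bookkeeping, with invented sizes; it only reproduces the index arithmetic, not the runtime API.

```cpp
#include <cstddef>

int main()
{
    // Fake view used for slicing: [genBatch, layers, 2, maxBlocks].
    constexpr std::size_t genBatch = 8, layers = 4, two = 2, maxBlocks = 16;
    constexpr std::size_t elemsPerBatchRow = layers * two * maxBlocks; // 128 pointers per batch row

    // slice(offset = 2, size = 3) along the leading dim selects the contiguous
    // element range [2 * elemsPerBatchRow, 5 * elemsPerBatchRow); the reshape
    // afterwards restores the [layers, batch, 2, maxBlocks] layout.
    constexpr std::size_t first = 2 * elemsPerBatchRow;
    constexpr std::size_t count = 3 * elemsPerBatchRow;
    static_assert(first == 256 && count == 384, "shape bookkeeping example");
    (void) genBatch;
    return 0;
}
```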
void RuntimeBuffers::gatherLastTokenLogits(
    BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    TLLM_CHECK_WITH_INFO(modelConfig.computeContextLogits(),
@@ -294,8 +380,29 @@ void RuntimeBuffers::gatherLastTokenLogits(BufferManager& manager, GenerationCon
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::tile(BufferManager& manager, GenerationConfig const& generationConfig,
    GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
void RuntimeBuffers::copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const batchSize = generationConfig.batchSize;
    auto const maxInputLength = generationConfig.maxInputLength;

    // TODO(rkobus) include tiling
    attentionMask = manager.gpu(ITensor::makeShape({batchSize, maxInputLength}), nvinfer1::DataType::kINT32);

    auto const numContextBatches = static_cast<SizeType>(contextBatches.size());
    auto offset = 0;
    for (auto contextBatchId = 0; contextBatchId < numContextBatches; ++contextBatchId)
    {
        auto& buffers = contextBatches.at(contextBatchId);
        auto contextBatchSize = buffers.generationConfig.batchSize;
        auto attentionMaskSlice = ITensor::slice(attentionMask, offset, contextBatchSize);
        manager.copy(*buffers.attentionMask, *attentionMaskSlice);
        offset += contextBatchSize;
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const beamWidth = generationConfig.beamWidth;
@@ -333,7 +440,7 @@ void RuntimeBuffers::tile(BufferManager& manager, GenerationConfig const& genera
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::postContextStep(BufferManager& manager, GenerationConfig const& generationConfig,
void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
    GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -346,15 +453,22 @@ void RuntimeBuffers::postContextStep(BufferManager& manager, GenerationConfig co
        auto hostRequestTypes = bufferCast<int32_t>(*requestTypes);
        std::fill_n(hostRequestTypes, requestTypes->getSize(), 1);
    }
    else
    {
        copyAttentionMasks(contextBuffers, manager);
    }

    // TODO(rkobus) handle this more gracefully
    positionIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);

    if (modelConfig.computeContextLogits())
    {
        gatherLastTokenLogits(manager, generationConfig, modelConfig, worldConfig);
        gatherLastTokenLogits(manager, modelConfig, worldConfig);
    }

    if (beamWidth > 1)
    {
        tile(manager, generationConfig, modelConfig, worldConfig);
        tile(manager, modelConfig, worldConfig);
    }

    // use output lengths after context step
@@ -371,12 +485,25 @@ void RuntimeBuffers::postContextStep(BufferManager& manager, GenerationConfig co
        kvCacheBlockPointersHost->reshape(cacheBlockPointersShape);
        kvCacheBlockPointersDevice->reshape(cacheBlockPointersShape);
    }

    if (modelConfig.usePromptTuning())
    {
        std::vector<SizeType> reqBeamWidths(batchSize, beamWidth);
        //// Note: reqPromptLengths won't be used
        std::vector<SizeType> reqPromptLengths;
        // Copy the generationInput tasks to host
        promptTuningTasksHost = manager.copyFrom(*promptTuningParams.tasks, MemoryType::kPINNED);
        // Update the promptTuningParams tasks tensor
        promptTuningParams.fillTasksTensor(promptTuningTasksHost, batchSize, 0, reqBeamWidths, reqPromptLengths,
            manager, modelConfig.usePackedInput());
    }

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
    KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
    GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
    KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
    WorldConfig const& worldConfig)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto& stream = manager.getStream();
@@ -391,12 +518,10 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
    auto pastKeyValueLengthsPtr = bufferCast<SizeType>(*pastKeyValueLengths);
    TLLM_CHECK(pastKeyValueLengths->getSize() == static_cast<std::size_t>(batchSize));
    std::fill_n(pastKeyValueLengthsPtr, batchSize, 0);
    if (modelConfig.useGptAttentionPlugin())
    {
        auto RequestTypesPtr = bufferCast<int32_t>(*requestTypes);
        TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
}
|
||||
|
||||
auto RequestTypesPtr = bufferCast<int32_t>(*requestTypes);
|
||||
TLLM_CHECK(requestTypes->getSize() == static_cast<std::size_t>(batchSize));
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
|
||||
auto const& inputShape = inputIds->getShape();
|
||||
auto const contextLengthsHostPtr = bufferCast<SizeType const>(*contextLengthsHost);
|
||||
@ -417,10 +542,19 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
}
|
||||
else if (modelVariant == GptModelConfig::ModelVariant::kGlm)
|
||||
{
|
||||
auto const positionIdsVec = getPositionIdsContextPhaseGlm(
|
||||
batchSize, maxInputLength, contextLengthsHostPtr, modelConfig.useGptAttentionPlugin());
|
||||
auto const positionIdsShape = ITensor::makeShape({batchSize, 2, maxInputLength});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
auto const positionIdsVec = getPositionIdsContextPhaseGlm(batchSize, maxInputLength, contextLengthsHostPtr,
|
||||
modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());
|
||||
if (modelConfig.usePackedInput())
|
||||
{
|
||||
int num_tokens = (int) positionIdsVec.size() / 2;
|
||||
auto const positionIdsShape = ITensor::makeShape({1, 2, num_tokens});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const positionIdsShape = ITensor::makeShape({batchSize, 2, maxInputLength});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -433,6 +567,23 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
auto const hiddenStatesShape = ITensor::makeShape({inputShape.d[0], inputShape.d[1], hiddenSize});
|
||||
hiddenStates->reshape(hiddenStatesShape);
|
||||
}
|
||||
|
||||
if (modelConfig.usePromptTuning())
|
||||
{
|
||||
std::vector<SizeType> reqBeamWidths(batchSize, 1);
|
||||
std::vector<SizeType> reqPromptLengths;
|
||||
for (SizeType i = 0; i < batchSize; ++i)
|
||||
{
|
||||
reqPromptLengths.push_back(contextLengthsHostPtr[i]);
|
||||
}
|
||||
|
||||
// Copy the generationInput tasks to host
|
||||
promptTuningTasksHost = manager.copyFrom(*promptTuningParams.tasks, MemoryType::kPINNED);
|
||||
|
||||
// Update the tasks tensor
|
||||
promptTuningParams.fillTasksTensor(promptTuningTasksHost, batchSize, batchSize, reqBeamWidths,
|
||||
reqPromptLengths, manager, modelConfig.usePackedInput());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -470,14 +621,12 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
manager.copy(*contextLengthsDevice, *lastTokenIds);
|
||||
}
|
||||
|
||||
manager.setZero(*cacheIndirectionDecoderInput);
|
||||
manager.setZero(*cacheIndirectionDecoderOutput);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, BufferManager& manager,
|
||||
KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
KvCacheManager* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& stream = manager.getStream();
|
||||
@ -519,10 +668,18 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
|
||||
}
|
||||
else if (modelVariant == GptModelConfig::ModelVariant::kGlm)
|
||||
{
|
||||
auto const positionIdsVec = getPositionIdsGenerationPhaseGlm(
|
||||
batchSize, beamWidth, step, contextLengthsHostPtr, modelConfig.useGptAttentionPlugin());
|
||||
auto const positionIdsShape = ITensor::makeShape({batchSize * beamWidth, 2, 1});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
auto const positionIdsVec = getPositionIdsGenerationPhaseGlm(batchSize, beamWidth, step,
|
||||
contextLengthsHostPtr, modelConfig.useGptAttentionPlugin(), modelConfig.usePackedInput());
|
||||
if (modelConfig.usePackedInput())
|
||||
{
|
||||
auto const positionIdsShape = ITensor::makeShape({1, 2, batchSize * beamWidth});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const positionIdsShape = ITensor::makeShape({batchSize * beamWidth, 2, 1});
|
||||
positionIds = manager.copyFrom(positionIdsVec, positionIdsShape, MemoryType::kGPU);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -538,7 +695,7 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const shape = attentionMask->getShape();
|
||||
auto const& shape = attentionMask->getShape();
|
||||
auto const nbInputs = shape.d[0];
|
||||
auto const oldLength = shape.d[1];
|
||||
auto const newLength = oldLength + 1;
|
||||
@ -583,13 +740,13 @@ RuntimeBuffers::TensorPtr RuntimeBuffers::prepareNextStep(SizeType const step, B
|
||||
{
|
||||
kernels::invokeInclusiveSum(*lastTokenIds, *lastTokenIds, manager, stream);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return nextInputIds;
|
||||
}
|
||||
|
||||
void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outputBuffers, SizeType const step,
|
||||
TensorPtr const& inputIds, GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const
|
||||
TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
inputBuffers.clear();
|
||||
@ -676,49 +833,110 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
{
|
||||
inputBuffers.insert_or_assign("all_reduce_workspace", commPtrs);
|
||||
}
|
||||
|
||||
if (modelConfig.usePromptTuning())
|
||||
{
|
||||
inputBuffers.insert_or_assign("prompt_embedding_table", promptTuningParams.embeddingTable);
|
||||
inputBuffers.insert_or_assign("tasks", promptTuningParams.tasks);
|
||||
inputBuffers.insert_or_assign("prompt_vocab_size", promptTuningParams.vocabSize);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(
SizeType batchSize, SizeType maxInputLength, SizeType const* pInputLengths, bool useGptAttentionPlugin)
std::vector<SizeType> RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize,
const SizeType& maxInputLength, const SizeType* pInputLengths, bool useGptAttentionPlugin, bool usePackedInput)
{
TLLM_CHECK(pInputLengths != nullptr);

auto const size = batchSize * 2 * maxInputLength;
std::vector<SizeType> positionIdsVec(size, 0);

for (SizeType b = 0; b < batchSize; ++b)
std::vector<SizeType> positionIdsVec(1, 0);
if (useGptAttentionPlugin)
{
auto* pIdB = positionIdsVec.data() + b * 2 * maxInputLength;
auto const length = pInputLengths[b];
std::iota(pIdB, pIdB + length, 0);

pIdB[length - 1] = length - 2;
pIdB[length - 1 + maxInputLength] = 1;
}

return positionIdsVec;
}

std::vector<SizeType> RuntimeBuffers::getPositionIdsGenerationPhaseGlm(
SizeType batchSize, SizeType beamSize, SizeType step, SizeType const* pInputLengths, bool useGptAttentionPlugin)
{
TLLM_CHECK(pInputLengths != nullptr);

auto const size = batchSize * beamSize * 2;
std::vector<SizeType> positionIdsVec(size, 0);

for (SizeType b = 0; b < batchSize; ++b)
{
auto* pIdB = positionIdsVec.data() + b * beamSize * 2;
auto const length = pInputLengths[b * beamSize];

for (SizeType bm = 0; bm < beamSize; ++bm)
if (usePackedInput)
{
pIdB[bm * 2 + 0] = length - 2;
pIdB[bm * 2 + 1] = step + 2;
std::vector<int> pInputLengthsAcc = std::vector<int>(batchSize + 1, 0);
for (int i = 0; i < batchSize; ++i)
{
pInputLengthsAcc[i + 1] = pInputLengthsAcc[i] + pInputLengths[i];
}

auto const size = 1 * 2 * pInputLengthsAcc[batchSize];
positionIdsVec.resize(size, 0);
for (SizeType b = 0; b < batchSize; ++b)
{
auto* pIdB = positionIdsVec.data() + pInputLengthsAcc[b];
auto const length = pInputLengths[b];
std::iota(pIdB, pIdB + length, 0);

pIdB[length - 1] = length - 2;
pIdB[length - 1 + pInputLengthsAcc[batchSize]] = 1;
}
}
else
{
auto const size = batchSize * 2 * maxInputLength;
positionIdsVec.resize(size, 0);
for (SizeType b = 0; b < batchSize; ++b)
{
auto* pIdB = positionIdsVec.data() + b * 2 * maxInputLength;
auto const length = pInputLengths[b];
std::iota(pIdB, pIdB + length, 0);

pIdB[length - 1] = length - 2;
pIdB[length - 1 + maxInputLength] = 1;
}
}
}
else
{
TLLM_THROW("Unsupported model without GPT Attention Plugin");
}

return positionIdsVec;
}
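For ChatGLM the context-phase position ids carry two rows per sequence: the token positions, with the last entry clamped to `length - 2`, and a block-position row that is `1` only on the final token. Below is a minimal sketch of the padded (non-packed) layout built by the loop above; `glmContextPositionIds` is an illustrative standalone helper, not a library function:

```cpp
#include <cstdio>
#include <numeric>
#include <vector>

// Builds the two-row ChatGLM context-phase position ids for one sequence,
// following the padded (non-packed) layout above: row 0 is 0..length-1 with
// the last entry clamped to length-2, row 1 is zero except a 1 on the last
// (generation-mask) token.
std::vector<int> glmContextPositionIds(int length, int maxInputLength)
{
    std::vector<int> ids(2 * maxInputLength, 0);
    std::iota(ids.begin(), ids.begin() + length, 0);
    ids[length - 1] = length - 2;
    ids[length - 1 + maxInputLength] = 1;
    return ids;
}

int main()
{
    auto const ids = glmContextPositionIds(/*length=*/5, /*maxInputLength=*/8);
    // Expected: 0 1 2 3 3 0 0 0 | 0 0 0 0 1 0 0 0
    for (int row = 0; row < 2; ++row)
    {
        for (int i = 0; i < 8; ++i)
        {
            std::printf("%d ", ids[row * 8 + i]);
        }
        std::printf("\n");
    }
    return 0;
}
```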
|
||||
|
||||
std::vector<SizeType> RuntimeBuffers::getPositionIdsGenerationPhaseGlm(const SizeType& batchSize,
|
||||
const SizeType& beamSize, const SizeType& step, const SizeType* pInputLengths, bool useGptAttentionPlugin,
|
||||
bool usePackedInput)
|
||||
{
|
||||
TLLM_CHECK(pInputLengths != nullptr);
|
||||
|
||||
auto const size = 2 * batchSize * beamSize;
|
||||
std::vector<SizeType> positionIdsVec(size, 0);
|
||||
if (useGptAttentionPlugin)
|
||||
{
|
||||
if (usePackedInput)
|
||||
{
|
||||
for (SizeType b = 0; b < batchSize; ++b)
|
||||
{
|
||||
auto* pIdB = positionIdsVec.data() + b * beamSize * 2;
|
||||
auto const length = pInputLengths[b * beamSize];
|
||||
|
||||
for (SizeType bm = 0; bm < beamSize; ++bm)
|
||||
{
|
||||
pIdB[bm * 2 + 0] = length - 2;
|
||||
pIdB[bm * 2 + 1] = step + 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (SizeType b = 0; b < batchSize; ++b)
|
||||
{
|
||||
auto* pIdB = positionIdsVec.data() + b * beamSize * 2;
|
||||
auto const length = pInputLengths[b * beamSize];
|
||||
|
||||
for (SizeType bm = 0; bm < beamSize; ++bm)
|
||||
{
|
||||
pIdB[bm * 2 + 0] = length - 2;
|
||||
pIdB[bm * 2 + 1] = step + 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Unsupported model without GPT Attention Plugin");
|
||||
}
|
||||
|
||||
return positionIdsVec;
|
||||
}
|
||||
|
||||
@ -19,8 +19,12 @@
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/promptTuningParams.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
class KVCacheManager;
|
||||
@ -28,7 +32,6 @@ class KVCacheManager;
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
class IpcMemory;
|
||||
class TllmRuntime;
|
||||
|
||||
class RuntimeBuffers
|
||||
@ -40,11 +43,39 @@ protected:
|
||||
public:
|
||||
using TensorMap = StringPtrMap<ITensor>;
|
||||
|
||||
class GenerationConfig
{
public:
    GenerationConfig() = default;

    explicit GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength,
        SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
        : batchSize{batchSize}
        , beamWidth{beamWidth}
        , maxInputLength{maxInputLength}
        , maxSeqLength{maxSeqLength}
        , inputLengthSum{inputLengthSum}
    {
    }

    SizeType batchSize{};
    SizeType beamWidth{};
    SizeType maxInputLength{};
    SizeType maxSeqLength{};
    SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput.

    static GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths, bool inputPacked,
        SizeType beamWidth, SizeType maxSequenceLength);
};
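Note that `maxNewTokens` is no longer stored in `GenerationConfig`; the generation budget is bounded by `maxSeqLength` alone. A small illustrative sketch of that relationship; the `GenerationConfigSketch` struct and its values are hypothetical stand-ins, not the library type:

```cpp
#include <cstdio>

// Hypothetical local stand-in for RuntimeBuffers::GenerationConfig: with no
// stored maxNewTokens, the per-sequence generation budget is derived from
// maxSeqLength - maxInputLength.
struct GenerationConfigSketch
{
    int batchSize{};
    int beamWidth{};
    int maxInputLength{};
    int maxSeqLength{};
    int inputLengthSum{}; // only meaningful for packed input
};

int main()
{
    GenerationConfigSketch const cfg{/*batchSize=*/2, /*beamWidth=*/1,
        /*maxInputLength=*/128, /*maxSeqLength=*/512, /*inputLengthSum=*/200};
    int const maxNewTokens = cfg.maxSeqLength - cfg.maxInputLength;
    std::printf("at most %d new tokens per sequence\n", maxNewTokens);
    return 0;
}
```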
|
||||
|
||||
public:
|
||||
GenerationConfig generationConfig{};
|
||||
std::array<TensorMap, 2> inputBuffers{};
|
||||
std::array<TensorMap, 2> outputBuffers{};
|
||||
|
||||
// general
|
||||
TensorPtr contextLengthsHost;
|
||||
TensorPtr contextLengthsDevice;
|
||||
TensorPtr inputOffsets; // helper for packed input
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
|
||||
// engine
|
||||
TensorPtr logits;
|
||||
@ -57,6 +88,7 @@ public:
|
||||
|
||||
std::vector<TensorPtr> presentKeysVals;
|
||||
std::vector<TensorPtr> presentKeysValsAlt; // without attention plugin
|
||||
TensorPtr kvCacheBlockPointersHost; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockPointersDevice; // [numLayers, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
|
||||
// References to tmp buffers
|
||||
@ -74,82 +106,58 @@ public:
|
||||
// pipeline parallelism
|
||||
TensorPtr hiddenStates;
|
||||
|
||||
// tensor parallelism
|
||||
TensorPtr commPtrs;
|
||||
// Prompt tuning
|
||||
PromptTuningParams promptTuningParams;
|
||||
TensorPtr promptTuningTasksHost; // Tensor to hold tasks on host
|
||||
|
||||
bool allocated{false};
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;
|
||||
|
||||
public:
|
||||
class GenerationConfig
|
||||
{
|
||||
public:
|
||||
GenerationConfig() = default;
|
||||
|
||||
GenerationConfig(SizeType batchSize, SizeType beamWidth, SizeType maxInputLength, SizeType maxNewTokens,
|
||||
SizeType maxSeqLength, SizeType inputLengthSum = SizeType(0))
|
||||
: batchSize{batchSize}
|
||||
, beamWidth{beamWidth}
|
||||
, maxInputLength{maxInputLength}
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, maxSeqLength{maxSeqLength}
|
||||
, inputLengthSum{inputLengthSum}
|
||||
{
|
||||
}
|
||||
|
||||
SizeType batchSize{};
|
||||
SizeType beamWidth{};
|
||||
SizeType maxInputLength{};
|
||||
SizeType maxNewTokens{};
|
||||
SizeType maxSeqLength{};
|
||||
SizeType inputLengthSum{}; // Initialized only if inputPacked is set to true in fromInput.
|
||||
|
||||
static RuntimeBuffers::GenerationConfig fromInput(ITensor const& inputIds, ITensor const& inputLengths,
|
||||
bool const inputPacked, SizeType const beamWidth, SizeType const maxSequenceLength,
|
||||
std::optional<SizeType> const& maxNewTokensOpt);
|
||||
};
|
||||
|
||||
public:
|
||||
void clear();
|
||||
void clearTensorMaps();
|
||||
|
||||
void create(TllmRuntime& runtime, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void initContextLengths(TensorPtr const& inputLengths, BufferManager& manager);
|
||||
void initFromInput(ITensor const& inputIds, TensorPtr const& inputLengths, bool inputPacked, SizeType beamWidth,
|
||||
SizeType maxSequenceLength, BufferManager& manager);
|
||||
|
||||
void reshape(
|
||||
GenerationConfig const& generationConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
//! \brief Reshape buffers based on current GenerationConfig
|
||||
void reshape(GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void postContextStep(BufferManager& manager, GenerationConfig const& generationConfig,
|
||||
void reset(BufferManager& manager);
|
||||
|
||||
std::vector<RuntimeBuffers> split(
|
||||
SizeType contextBatchSize, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void prepareContextStep(TensorPtr const& inputIds, TokenIdType padId, BufferManager& manager,
|
||||
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
TensorPtr prepareNextStep(SizeType step, BufferManager& manager, KvCacheManager* kvCacheManager,
|
||||
SizeType firstBatchSlotIdx, GenerationConfig const& generationConfig, GptModelConfig const& modelConfig,
|
||||
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig);
|
||||
TensorPtr prepareNextStep(SizeType step, BufferManager& manager, KvCacheManager* kvCacheManager,
|
||||
SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outputBuffers, SizeType step, TensorPtr const& inputIds,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig) const;
|
||||
|
||||
void createCustomAllReduceWorkspace(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength,
|
||||
SizeType hiddenSize, WorldConfig const& worldConfig, BufferManager& manager);
|
||||
void getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outputBuffers, SizeType const step,
|
||||
TensorPtr const& inputIds, TensorPtr const& commPtrs, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig) const;
|
||||
|
||||
private:
|
||||
void gatherLastTokenLogits(BufferManager& manager, GenerationConfig const& generationConfig,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
void gatherLastTokenLogits(
|
||||
BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void copyAttentionMasks(std::vector<RuntimeBuffers> const& contextBatches, BufferManager& manager);
|
||||
|
||||
// Some tensors are properly tiled, some are just reshaped.
|
||||
void tile(BufferManager& manager, GenerationConfig const& generationConfig, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig);
|
||||
void tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
static std::vector<SizeType> getPositionIdsContextPhaseGlm(
|
||||
SizeType batchSize, SizeType maxInputLength, SizeType const* pInputLengths, bool useGptAttentionPlugin);
|
||||
static std::vector<SizeType> getPositionIdsContextPhaseGlm(const SizeType& batchSize,
|
||||
const SizeType& maxInputLength, const SizeType* pInputLengths, const bool useGptAttentionPlugin,
|
||||
const bool usePackedInput);
|
||||
|
||||
static std::vector<SizeType> getPositionIdsGenerationPhaseGlm(SizeType batchSize, SizeType beamSize, SizeType step,
|
||||
SizeType const* pInputLengths, bool useGptAttentionPlugin);
|
||||
static std::vector<SizeType> getPositionIdsGenerationPhaseGlm(const SizeType& batchSize, const SizeType& beamSize,
|
||||
const SizeType& step, const SizeType* pInputLengths, const bool useGptAttentionPlugin,
|
||||
const bool usePackedInput);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -747,6 +747,24 @@ void invokeCopyPackedInputToOutput(ITensor& outputIds, ITensor const& inputIds,
|
||||
maxInputLength, maxSeqLength);
|
||||
}
|
||||
|
||||
void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths,
    ITensor const& inputOffsets, TokenIdType const padId, TokenIdType const endId, SizeType const maxInputLength,
    bool const inputPacked, CudaStream const& stream)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    kernels::invokeFill(outputIds, endId, stream);

    if (inputPacked)
    {
        kernels::invokeCopyPackedInputToOutput(outputIds, inputIds, inputOffsets, maxInputLength, padId, stream);
    }
    else
    {
        kernels::invokeCopyInputToOutput(outputIds, inputIds, inputLengths, padId, stream);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
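In the padded path, `initOutputIds` pre-fills the output with `endId` and then copies each input sequence up to its own length. A host-side sketch of that behavior is shown below; it is illustrative only (the real code runs on the GPU via `invokeFill`/`invokeCopyInputToOutput`, and `padId` handling is omitted here):

```cpp
#include <cstdio>
#include <vector>

// Host-side illustration of the padded (non-packed) path of initOutputIds:
// the output is pre-filled with endId, then each input sequence is copied in
// up to its own length.
std::vector<int> initOutputIdsHost(std::vector<int> const& inputIds, std::vector<int> const& inputLengths,
    int maxInputLength, int maxSeqLength, int endId)
{
    auto const batchSize = static_cast<int>(inputLengths.size());
    std::vector<int> outputIds(static_cast<size_t>(batchSize) * maxSeqLength, endId);
    for (int b = 0; b < batchSize; ++b)
    {
        for (int i = 0; i < inputLengths[b]; ++i)
        {
            outputIds[b * maxSeqLength + i] = inputIds[b * maxInputLength + i];
        }
    }
    return outputIds;
}

int main()
{
    std::vector<int> const inputIds{11, 12, 13, 0, 21, 22, 0, 0}; // 2 sequences, maxInputLength = 4
    std::vector<int> const inputLengths{3, 2};
    auto const out = initOutputIdsHost(inputIds, inputLengths, /*maxInputLength=*/4, /*maxSeqLength=*/6, /*endId=*/2);
    for (int v : out)
    {
        std::printf("%d ", v); // 11 12 13 2 2 2 21 22 2 2 2 2
    }
    std::printf("\n");
    return 0;
}
```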
|
||||
|
||||
namespace
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
@ -68,6 +68,10 @@ void invokeCopyInputToOutput(
|
||||
void invokeCopyPackedInputToOutput(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputOffsets,
|
||||
SizeType maxInputLength, SizeType padId, CudaStream const& stream);
|
||||
|
||||
void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths,
|
||||
ITensor const& inputOffsets, TokenIdType padId, TokenIdType endId, SizeType maxInputLength, bool inputPacked,
|
||||
CudaStream const& stream);
|
||||
|
||||
void scatterTensor(ITensor& output, ITensor const& input, SizeType beamWidth, CudaStream const& stream);
|
||||
|
||||
void tileTensor(ITensor& output, ITensor const& input, SizeType beamWidth, CudaStream const& stream);
|
||||
|
||||
@ -114,32 +114,10 @@ void StatefulGptDecoder::reshapeBuffers(SizeType batchSize, SizeType beamWidth,
|
||||
dOutput.beamHypotheses.release();
|
||||
}
|
||||
|
||||
mMaxNewTokens = 0;
|
||||
mNbSteps = 0;
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
void initOutputIds(TensorPtr const& outputIds, TensorPtr const& inputIds, TensorPtr const& inputLengths,
|
||||
TensorPtr const& inputOffsets, SizeType const padId, SizeType const endId, SizeType const maxInputLength,
|
||||
bool const inputPacked, CudaStream const& stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
kernels::invokeFill(*outputIds, endId, stream);
|
||||
|
||||
if (inputPacked)
|
||||
{
|
||||
kernels::invokeCopyPackedInputToOutput(*outputIds, *inputIds, *inputOffsets, maxInputLength, padId, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
kernels::invokeCopyInputToOutput(*outputIds, *inputIds, *inputLengths, padId, stream);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
@ -174,11 +152,6 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
kernels::invokeInclusiveSum(*ITensor::slice(inputOffsets, 1), *inputLengths, manager, *stream);
|
||||
}
|
||||
|
||||
mMaxNewTokens = inputs.maxNewTokens.value_or(mMaxSequenceLength - maxInputLength);
|
||||
TLLM_CHECK_WITH_INFO(maxInputLength + mMaxNewTokens <= mMaxSequenceLength,
|
||||
tc::fmtstr("Input length (%d) + max new tokens (%d) must be less than max sequence length (%d).",
|
||||
maxInputLength, mMaxNewTokens, mMaxSequenceLength));
|
||||
|
||||
TLLM_CHECK(inputIds->getDataType() == TRTDataType<TokenIdType>::value);
|
||||
auto const endId = inputs.endId;
|
||||
auto const padId = inputs.padId;
|
||||
@ -191,9 +164,21 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
dInput.embeddingBias = inputs.embeddingBiasOpt;
|
||||
dInput.badWordsList = inputs.badWordsList;
|
||||
dInput.stopWordsList = inputs.stopWordsList;
|
||||
kernels::invokeFill(const_cast<ITensor&>(*dInput.sequenceLimitLength), mMaxSequenceLength, *stream);
|
||||
auto inputLengthsView = ITensor::view(dInput.lengths, ITensor::makeShape({batchSize * beamWidth}));
|
||||
kernels::tileTensor(const_cast<ITensor&>(*inputLengthsView), *inputLengths, beamWidth, *stream);
|
||||
if (inputs.maxNewTokens)
{
    auto const maxNewTokens = inputs.maxNewTokens.value();
    TLLM_CHECK_WITH_INFO(maxInputLength + maxNewTokens <= mMaxSequenceLength,
        tc::fmtstr("Input length (%d) + max new tokens (%d) must be less than max sequence length (%d).",
            maxInputLength, maxNewTokens, mMaxSequenceLength));
    manager.copy(*inputLengths, const_cast<ITensor&>(*dInput.sequenceLimitLength));
    kernels::invokeAdd(const_cast<ITensor&>(*dInput.sequenceLimitLength), maxNewTokens, *stream);
}
else
{
    kernels::invokeFill(const_cast<ITensor&>(*dInput.sequenceLimitLength), mMaxSequenceLength, *stream);
}
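The branch above fills `sequenceLimitLength` either with per-sequence limits (`inputLength + maxNewTokens`) or with the global `mMaxSequenceLength`. A host-side sketch of the same rule follows; the `sequenceLimitLengths` helper and its inputs are illustrative, while the real code writes a GPU tensor with `invokeFill`/`invokeAdd`:

```cpp
#include <cassert>
#include <cstdio>
#include <optional>
#include <vector>

// With maxNewTokens given, each sequence may grow to its own
// inputLength + maxNewTokens; otherwise all sequences share the global
// maxSequenceLength cap.
std::vector<int> sequenceLimitLengths(
    std::vector<int> const& inputLengths, int maxSequenceLength, std::optional<int> maxNewTokens)
{
    std::vector<int> limits(inputLengths.size(), maxSequenceLength);
    if (maxNewTokens)
    {
        for (size_t i = 0; i < inputLengths.size(); ++i)
        {
            assert(inputLengths[i] + *maxNewTokens <= maxSequenceLength);
            limits[i] = inputLengths[i] + *maxNewTokens;
        }
    }
    return limits;
}

int main()
{
    auto const limits = sequenceLimitLengths({10, 20, 15}, /*maxSequenceLength=*/64, /*maxNewTokens=*/32);
    for (int l : limits)
    {
        std::printf("%d ", l); // 42 52 47
    }
    std::printf("\n");
    return 0;
}
```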
|
||||
|
||||
// output
|
||||
auto& dOutput = *mDecodingOutput;
|
||||
@ -227,8 +212,8 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
}
|
||||
|
||||
// copy the request ids into dOutput.ids (with tiling)
|
||||
initOutputIds(
|
||||
dOutput.ids, inputIds, inputLengths, inputOffsets, padId, endId, maxInputLength, inputs.packed, *stream);
|
||||
kernels::initOutputIds(
|
||||
*dOutput.ids, *inputIds, *inputLengths, *inputOffsets, padId, endId, maxInputLength, inputs.packed, *stream);
|
||||
|
||||
// remaining
|
||||
mNbSteps = 0;
|
||||
|
||||
@ -90,6 +90,5 @@ private:
|
||||
|
||||
SizeType mNbSteps;
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxNewTokens;
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
cpp/tensorrt_llm/runtime/utils/debugUtils.cu (new file, 66 lines)
@ -0,0 +1,66 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "debugUtils.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"

namespace
{

__global__ void checkTensorNanKernel(const float* data, std::size_t size, int* foundNan)
{
    auto tidx = blockIdx.x * blockDim.x + threadIdx.x;

    int32_t found = 0;

    for (auto idx = tidx; idx < size; idx += blockDim.x * gridDim.x)
    {
        auto value = data[idx];
        if (isnan(value))
        {
            found = 1;
            break;
        }
    }
    atomicCAS(foundNan, 0, found);
}
} // namespace

using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;

namespace tensorrt_llm::runtime::utils
{

void invokeCheckTensorNanKernel(const float* data, std::size_t size, int* foundNan, cudaStream_t stream)
{
    constexpr uint32_t kThreadsPerCta = 256;
    checkTensorNanKernel<<<tc::ceilDiv(size, kThreadsPerCta), kThreadsPerCta, 0, stream>>>(data, size, foundNan);
}

bool tensorHasNan(const IBuffer& tensor, BufferManager& manager)
{
    auto foundNan = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
    auto foundNanPtr = bufferCast<int32_t>(*foundNan);
    foundNanPtr[0] = 0;
    const auto size = tensor.getSize();
    invokeCheckTensorNanKernel(bufferCast<float>(tensor), size, foundNanPtr, manager.getStream().get());
    manager.getStream().synchronize();
    return static_cast<bool>(foundNanPtr[0]);
}
} // namespace tensorrt_llm::runtime::utils
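The kernel scans the buffer with a grid-stride loop and publishes a single flag through `atomicCAS`. A host-side equivalent of the same check, useful for reasoning about the expected result (illustrative sketch, not part of the library):

```cpp
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Host-side equivalent of the NaN check above: the CUDA kernel walks the
// buffer with a grid-stride loop and publishes one flag via atomicCAS; on the
// host a simple scan with std::isnan gives the same answer.
bool hasNan(std::vector<float> const& data)
{
    for (float v : data)
    {
        if (std::isnan(v))
        {
            return true;
        }
    }
    return false;
}

int main()
{
    std::vector<float> const clean{1.0f, 2.0f, 3.0f};
    std::vector<float> const broken{1.0f, std::numeric_limits<float>::quiet_NaN(), 3.0f};
    std::printf("clean: %d, broken: %d\n", hasNan(clean), hasNan(broken)); // clean: 0, broken: 1
    return 0;
}
```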
cpp/tensorrt_llm/runtime/utils/debugUtils.h (new file, 29 lines)
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
namespace utils
|
||||
{
|
||||
|
||||
bool tensorHasNan(const IBuffer& tensor, BufferManager& manager);
|
||||
|
||||
}
|
||||
} // namespace tensorrt_llm::runtime
|
||||
@ -89,6 +89,13 @@ void reshapeBufferVector(std::vector<ITensor::SharedPtr>& vector, nvinfer1::Dims
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ITensor::SharedPtr> sliceBufferVector(
    std::vector<ITensor::SharedPtr> const& vector, SizeType const offset, SizeType const size)
{
    return transformVector(
        vector, [offset, size](auto const& buffer) { return std::shared_ptr{ITensor::slice(buffer, offset, size)}; });
}

void insertTensorVector(StringPtrMap<ITensor>& map, std::string const& key, std::vector<ITensor::SharedPtr> const& vec,
    SizeType const indexOffset)
{

@ -37,6 +37,16 @@ int initDevice(WorldConfig const& worldConfig);

std::vector<uint8_t> loadEngine(std::string const& enginePath);

template <typename TInputContainer, typename TFunc>
auto transformVector(TInputContainer const& input, TFunc func)
    -> std::vector<std::remove_reference_t<decltype(func(input.front()))>>
{
    std::vector<std::remove_reference_t<decltype(func(input.front()))>> output{};
    output.reserve(input.size());
    std::transform(input.begin(), input.end(), std::back_inserter(output), func);
    return output;
}
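`transformVector` maps a callable over a container and collects the results; `sliceBufferVector` uses it with `ITensor::slice` as the callable. A self-contained copy exercised on plain data (the example inputs are illustrative):

```cpp
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>

// Self-contained copy of the transformVector helper above, applied to plain
// data instead of ITensor handles.
template <typename TInputContainer, typename TFunc>
auto transformVector(TInputContainer const& input, TFunc func)
    -> std::vector<std::remove_reference_t<decltype(func(input.front()))>>
{
    std::vector<std::remove_reference_t<decltype(func(input.front()))>> output{};
    output.reserve(input.size());
    std::transform(input.begin(), input.end(), std::back_inserter(output), func);
    return output;
}

int main()
{
    std::vector<int> const lengths{3, 5, 8};
    auto const labels = transformVector(lengths, [](int len) { return std::to_string(len) + " tokens"; });
    for (auto const& s : labels)
    {
        std::printf("%s\n", s.c_str());
    }
    return 0;
}
```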
|
||||
|
||||
std::vector<ITensor::SharedPtr> createBufferVector(TllmRuntime const& runtime, SizeType indexOffset,
|
||||
SizeType numBuffers, std::string const& prefix, MemoryType memType);
|
||||
|
||||
@ -45,6 +55,9 @@ std::vector<ITensor::SharedPtr> createBufferVector(
|
||||
|
||||
void reshapeBufferVector(std::vector<ITensor::SharedPtr>& vector, nvinfer1::Dims const& shape);
|
||||
|
||||
std::vector<ITensor::SharedPtr> sliceBufferVector(
|
||||
std::vector<ITensor::SharedPtr> const& vector, SizeType offset, SizeType size);
|
||||
|
||||
void insertTensorVector(StringPtrMap<ITensor>& map, std::string const& key, std::vector<ITensor::SharedPtr> const& vec,
|
||||
SizeType indexOffset);
|
||||
|
||||
|
||||
@ -21,6 +21,5 @@ target_link_libraries(th_utils PUBLIC ${TORCH_LIBRARIES} ${CUBLAS_LIB}
|
||||
add_library(th_common SHARED dynamicDecodeOp.cpp weightOnlyQuantOp.cpp
|
||||
gatherTreeOp.cpp fp8Op.cpp ncclCommunicatorOp.cpp)
|
||||
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_link_libraries(
|
||||
th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
|
||||
${STATIC_TARGET} ${UNDEFINED_FLAG})
|
||||
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
|
||||
${Python3_LIBRARIES} ${STATIC_TARGET})
|
||||
|
||||
@ -74,6 +74,7 @@ add_gtest(tllmBuffersTest runtime/tllmBuffersTest.cpp)
|
||||
add_gtest(bufferManagerTest runtime/bufferManagerTest.cpp)
|
||||
add_gtest(runtimeKernelTest runtime/runtimeKernelTest.cpp)
|
||||
add_gtest(samplingTest runtime/samplingTest.cpp)
|
||||
add_gtest(iTensorTest runtime/iTensorTest.cpp)
|
||||
add_gtest(torchTest runtime/torchTest.cpp)
|
||||
set(SAMPLING_KERNEL_TEST_SRC
|
||||
kernels/sampling/samplingTest.cpp
|
||||
|
||||
@ -36,7 +36,7 @@ To build the engines from the top-level directory:
|
||||
PYTHONPATH=examples/gpt:$PYTHONPATH python3 cpp/tests/resources/scripts/build_gpt_engines.py
|
||||
PYTHONPATH=examples/gptj:$PYTHONPATH python3 cpp/tests/resources/scripts/build_gptj_engines.py
|
||||
PYTHONPATH=examples/llama:$PYTHONPATH python3 cpp/tests/resources/scripts/build_llama_engines.py
|
||||
PYTHONPATH=examples/CHATGLM6B:$PYTHONPATH python3 cpp/tests/resources/scripts/build_chatglm6b_engines.py
|
||||
PYTHONPATH=examples/chatglm:$PYTHONPATH python3 cpp/tests/resources/scripts/build_chatglm_engines.py
|
||||
```
|
||||
|
||||
It is possible to build engines with tensor and pipeline parallelism for LLaMA using 4 GPUs.
|
||||
@ -53,8 +53,7 @@ End-to-end tests read inputs and expected outputs from Numpy files located at [c
|
||||
PYTHONPATH=examples/gpt:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_gpt_output.py
|
||||
PYTHONPATH=examples/gptj:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_gptj_output.py
|
||||
PYTHONPATH=examples/llama:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_llama_output.py
|
||||
PYTHONPATH=examples/chatglm6b:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm6b_output.py
|
||||
PYTHONPATH=examples/chatglm2-6b:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm2-6b_output.py
|
||||
PYTHONPATH=examples/chatglm:$PYTHONPATH python3 cpp/tests/resources/scripts/generate_expected_chatglm_output.py
|
||||
```
|
||||
|
||||
### Generate data with tensor and pipeline parallelism
|
||||
|
||||
@ -1,112 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse as _arg
|
||||
import os as _os
|
||||
import pathlib as _pl
|
||||
import subprocess as _sp
|
||||
import sys
|
||||
import typing as _tp
|
||||
from glob import glob as _glob
|
||||
|
||||
import torch.multiprocessing as _mp
|
||||
|
||||
resources_dir = _pl.Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm6b"
|
||||
sys.path.insert(0, str(resources_dir))
|
||||
|
||||
engine_target_path = _pl.Path(
|
||||
__file__).parent.parent / "models/rt_engine/chatglm6b"
|
||||
|
||||
import build as _ecb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, world_size, *args):
|
||||
args = [
|
||||
'--log_level=error',
|
||||
'--model_dir',
|
||||
str(weight_dir),
|
||||
'--output_dir',
|
||||
str(engine_dir),
|
||||
'--max_batch_size=2',
|
||||
'--max_beam_width=2',
|
||||
'--builder_opt=0',
|
||||
f'--world_size={world_size}',
|
||||
] + list(args)
|
||||
print("Running: " + " ".join(args))
|
||||
_ecb.run_build(args)
|
||||
|
||||
|
||||
def run_command(command: _tp.Sequence[str], *, cwd=None, **kwargs) -> None:
|
||||
|
||||
command = [str(i) for i in command]
|
||||
print(f"Running: cd %s && %s" %
|
||||
(str(cwd or _pl.Path.cwd()), " ".join(command)))
|
||||
_sp.check_call(command, cwd=cwd, **kwargs)
|
||||
|
||||
|
||||
def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
|
||||
# Clone the model directory
|
||||
hf_dir = resources_dir / "pyTorchModel"
|
||||
trt_dir = resources_dir / "trtModel"
|
||||
|
||||
run_command(
|
||||
["pip", "install", "-r",
|
||||
str(resources_dir) + "/requirements.txt"],
|
||||
cwd=resources_dir)
|
||||
|
||||
if not _os.path.exists(hf_dir):
|
||||
_os.mkdir(hf_dir)
|
||||
|
||||
if len(_glob(str(hf_dir) + "/*")) == 0:
|
||||
run_command(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"https://huggingface.co/THUDM/chatglm-6b",
|
||||
hf_dir,
|
||||
],
|
||||
cwd=resources_dir,
|
||||
)
|
||||
|
||||
print("\nBuilding engine")
|
||||
build_engine(hf_dir, trt_dir, world_size, "--dtype", "float16",
|
||||
"--use_gpt_attention_plugin", "float16", "--use_gemm_plugin",
|
||||
"float16")
|
||||
|
||||
if not _os.path.exists(str(engine_target_path)):
|
||||
_os.system(f"mkdir -p {str(engine_target_path)}")
|
||||
|
||||
_os.system(f"cp -r {str(trt_dir) + '/*'} {engine_target_path}")
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = _arg.ArgumentParser()
|
||||
parser.add_argument("--model_cache",
|
||||
type=str,
|
||||
help="Directory where models are stored")
|
||||
|
||||
parser.add_argument('--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now')
|
||||
|
||||
_mp.set_start_method("spawn")
|
||||
|
||||
build_engines(**vars(parser.parse_args()))
|
||||
@ -15,27 +15,31 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse as _arg
|
||||
import os as _os
|
||||
import pathlib as _pl
|
||||
import shutil as _shutil
|
||||
import subprocess as _sp
|
||||
import sys
|
||||
import typing as _tp
|
||||
from glob import glob as _glob
|
||||
from collections import OrderedDict as _OrderedDict
|
||||
from pathlib import Path as _Path
|
||||
|
||||
import torch.multiprocessing as _mp
|
||||
|
||||
resources_dir = _pl.Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm2-6b"
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm"
|
||||
sys.path.insert(0, str(resources_dir))
|
||||
|
||||
engine_target_path = _pl.Path(
|
||||
__file__).parent.parent / "models/rt_engine/chatglm2-6b"
|
||||
__file__).parent.parent / "models/rt_engine/chatglm"
|
||||
|
||||
import build as _ecb
|
||||
|
||||
|
||||
def build_engine(weight_dir: _pl.Path, engine_dir: _pl.Path, world_size, *args):
|
||||
def build_engine(model_version: str, weight_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
world_size, *args):
|
||||
args = [
|
||||
'-m',
|
||||
str(model_version),
|
||||
'--log_level=error',
|
||||
'--model_dir',
|
||||
str(weight_dir),
|
||||
@ -60,8 +64,14 @@ def run_command(command: _tp.Sequence[str], *, cwd=None, **kwargs) -> None:
|
||||
|
||||
def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
|
||||
# Clone the model directory
|
||||
hf_dir = resources_dir / "pyTorchModel"
|
||||
model_name_dict = _OrderedDict([
|
||||
["chatglm-6b", "1"],
|
||||
["chatglm2-6b", "2"],
|
||||
["chatglm3-6b", "3"],
|
||||
])
|
||||
hf_dir_list = [
|
||||
resources_dir / model_name for model_name in model_name_dict.keys()
|
||||
]
|
||||
trt_dir = resources_dir / "trtModel"
|
||||
|
||||
run_command(
|
||||
@ -69,29 +79,27 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
str(resources_dir) + "/requirements.txt"],
|
||||
cwd=resources_dir)
|
||||
|
||||
if not _os.path.exists(hf_dir):
|
||||
_os.mkdir(hf_dir)
|
||||
# Clone the model directory
|
||||
for model_name, hf_dir in zip(model_name_dict.keys(), hf_dir_list):
|
||||
if not _Path(hf_dir).exists():
|
||||
run_command(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"https://huggingface.co/THUDM/" + model_name,
|
||||
],
|
||||
cwd=resources_dir,
|
||||
)
|
||||
|
||||
if len(_glob(str(hf_dir) + "/*")) == 0:
|
||||
run_command(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"https://huggingface.co/THUDM/chatglm2-6b",
|
||||
hf_dir,
|
||||
],
|
||||
cwd=resources_dir,
|
||||
)
|
||||
print("\nBuilding engines")
|
||||
for model, hf_dir in zip(model_name_dict.items(), hf_dir_list):
|
||||
print("Building %s" % model[0])
|
||||
build_engine(model[1], hf_dir, trt_dir, world_size)
|
||||
|
||||
print("\nBuilding engine")
|
||||
build_engine(hf_dir, trt_dir, world_size, "--dtype", "float16",
|
||||
"--use_gpt_attention_plugin", "float16", "--use_gemm_plugin",
|
||||
"float16")
|
||||
|
||||
if not _os.path.exists(str(engine_target_path)):
|
||||
_os.system(f"mkdir -p {str(engine_target_path)}")
|
||||
|
||||
_os.system(f"cp -r {str(trt_dir) + '/*'} {engine_target_path}")
|
||||
if not _Path(engine_target_path).exists():
|
||||
_Path(engine_target_path).mkdir(parents=True, exist_ok=True)
|
||||
for file in _Path(trt_dir).glob("*"):
|
||||
_shutil.move(file, engine_target_path)
|
||||
|
||||
print("Done.")
|
||||
|
||||
@ -1,163 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib as _pl
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
from tensorrt_llm.runtime import GenerationSession, ModelConfig, SamplingConfig
|
||||
|
||||
resources_dir = _pl.Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm2-6b"
|
||||
sys.path.insert(0, str(resources_dir))
|
||||
|
||||
from run import parse_arguments # isort:skip
|
||||
|
||||
from build import find_engines # isort:skip
|
||||
|
||||
MODEL_NAME = "chatglm2-6b"
|
||||
|
||||
|
||||
def generate(batch_size, beam_width):
|
||||
|
||||
print("generate expected ChatGLM2-6B output BatchSize=%d, BeamWidth=%d" %
|
||||
(batch_size, beam_width))
|
||||
args = parse_arguments()
|
||||
if batch_size == 1:
|
||||
args.input_text = args.input_text[:1]
|
||||
elif batch_size > 2:
|
||||
args.input_text += args.input_text[0] * (batch_size - 2)
|
||||
args.beam_width = beam_width
|
||||
args.tokenizer_dir = resources_dir / "pyTorchModel"
|
||||
args.engine_dir = _pl.Path(
|
||||
__file__).parent.parent / "models/rt_engine/chatglm2-6b"
|
||||
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
assert (config['builder_config']['name'] == MODEL_NAME)
|
||||
dtype = config['builder_config']['precision']
|
||||
end_id = config['builder_config']['eos_token_id']
|
||||
pad_id = config['builder_config']['pad_token_id']
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == tensorrt_llm.mpi_world_size(
|
||||
), f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
runtime_mapping = tensorrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
serialize_path = find_engines(Path(args.engine_dir),
|
||||
dtype=dtype,
|
||||
tp_size=world_size,
|
||||
rank=runtime_rank)[0]
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir, trust_remote_code=True)
|
||||
input_text = args.input_text
|
||||
tokenized = tokenizer(input_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_length=True)
|
||||
input_ids = tokenized['input_ids'].int().contiguous().cuda()
|
||||
input_lengths = tokenized['length'].int().contiguous().cuda()
|
||||
|
||||
if use_gpt_attention_plugin:
|
||||
# when using gpt attention plugin, inputs needs to align at the head
|
||||
input_ids_padding_right = torch.zeros_like(input_ids) + end_id
|
||||
for i, sample in enumerate(input_ids):
|
||||
nPadding = 0
|
||||
for token in sample:
|
||||
if token == pad_id:
|
||||
nPadding += 1
|
||||
else:
|
||||
break
|
||||
input_ids_padding_right[
|
||||
i, :len(sample[nPadding:])] = sample[nPadding:]
|
||||
input_ids = input_ids_padding_right
|
||||
|
||||
model_config = ModelConfig(
|
||||
vocab_size=config['builder_config']['vocab_size'],
|
||||
num_layers=config['builder_config']['num_layers'],
|
||||
num_heads=config['builder_config']['num_heads'] // world_size,
|
||||
num_kv_heads=config['builder_config']['num_kv_heads'] // world_size,
|
||||
hidden_size=config['builder_config']['hidden_size'] // world_size,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=config['builder_config']['remove_input_padding'],
|
||||
model_name=MODEL_NAME,
|
||||
paged_kv_cache=config['builder_config']['paged_kv_cache'],
|
||||
quant_mode=QuantMode(config['builder_config']['quant_mode']),
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
end_id=end_id,
|
||||
pad_id=pad_id,
|
||||
num_beams=args.beam_width,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
)
|
||||
sampling_config.random_seed = args.random_seed
|
||||
|
||||
with open(serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
decoder = GenerationSession(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping,
|
||||
)
|
||||
decoder.setup(input_ids.size(0), input_ids.size(1), args.max_output_len,
|
||||
args.beam_width)
|
||||
output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
data_path = _pl.Path(__file__).parent.parent / "data/chatglm2-6b"
|
||||
if not os.path.exists(str(data_path)):
|
||||
os.mkdir(data_path)
|
||||
nBS, nBM = input_ids.size(0), args.beam_width
|
||||
np.save(
|
||||
str(data_path) + "/inputId-BS%d-BM%d.npy" % (nBS, nBM),
|
||||
input_ids.detach().cpu().numpy())
|
||||
outputId = output_ids.detach().cpu().numpy()
|
||||
|
||||
nMaxOutputLength = 0
|
||||
for single_output in outputId.reshape(nBS * nBM, -1):
|
||||
nMaxOutputLength = max(nMaxOutputLength,
|
||||
np.min(np.where(single_output == end_id)))
|
||||
np.save(
|
||||
str(data_path) + "/outputId-BS%d-BM%d.npy" % (nBS, nBM),
|
||||
outputId[:, :, :(nMaxOutputLength + 1)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generate(batch_size=1, beam_width=1)
|
||||
generate(batch_size=2, beam_width=1)
|
||||
generate(batch_size=1, beam_width=2)
|
||||
print("Finish!")
|
||||
@ -15,9 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib as _pl
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@ -26,40 +25,45 @@ import transformers
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
from tensorrt_llm.runtime import (ChatGLM6BHeadModelGenerationSession,
|
||||
from tensorrt_llm.runtime import (ChatGLMGenerationSession, GenerationSession,
|
||||
ModelConfig, SamplingConfig)
|
||||
|
||||
resources_dir = _pl.Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm6b"
|
||||
resources_dir = Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm"
|
||||
sys.path.insert(0, str(resources_dir))
|
||||
|
||||
from run import parse_arguments # isort:skip
|
||||
|
||||
from build import find_engines # isort:skip
|
||||
|
||||
MODEL_NAME = "chatglm-6b"
|
||||
|
||||
def generate(model_name, batch_size, beam_width):
|
||||
|
||||
def generate(batch_size, beam_width):
|
||||
model_name_dict = OrderedDict([
|
||||
["chatglm-6b", "1"],
|
||||
["chatglm2-6b", "2"],
|
||||
["chatglm3-6b", "3"],
|
||||
])
|
||||
|
||||
print("generate expected %s output BatchSize=%d, BeamWidth=%d" %
|
||||
(model_name, batch_size, beam_width))
|
||||
|
||||
print("generate expected ChatGLM-6B output BatchSize=%d, BeamWidth=%d" %
|
||||
(batch_size, beam_width))
|
||||
args = parse_arguments()
|
||||
if batch_size == 1:
|
||||
args.input_text = args.input_text[:1]
|
||||
elif batch_size > 2:
|
||||
args.input_text += args.input_text[0] * (batch_size - 2)
|
||||
args.model_version = model_name_dict[model_name]
|
||||
args.beam_width = beam_width
|
||||
args.tokenizer_dir = resources_dir / "pyTorchModel"
|
||||
args.engine_dir = _pl.Path(
|
||||
__file__).parent.parent / "models/rt_engine/chatglm6b"
|
||||
args.tokenizer_dir = resources_dir / model_name
|
||||
args.engine_dir = Path(__file__).parent.parent / "models/rt_engine/chatglm"
|
||||
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
config_path = Path(args.engine_dir) / (model_name + '-config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
assert (config['builder_config']['name'] == MODEL_NAME)
|
||||
assert (config['builder_config']['name'] == model_name)
|
||||
dtype = config['builder_config']['precision']
|
||||
end_id = config['builder_config']['eos_token_id']
|
||||
pad_id = config['builder_config']['pad_token_id']
@ -74,10 +78,13 @@ def generate(batch_size, beam_width):
tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

serialize_path = find_engines(Path(args.engine_dir),
dtype=dtype,
tp_size=world_size,
rank=runtime_rank)[0]
serialize_path = find_engines(
Path(args.engine_dir),
model_name=model_name,
dtype=dtype,
tp_size=world_size,
rank=runtime_rank,
)[0]

tokenizer = transformers.AutoTokenizer.from_pretrained(
args.tokenizer_dir, trust_remote_code=True)
@ -111,7 +118,7 @@ def generate(batch_size, beam_width):
hidden_size=config['builder_config']['hidden_size'] // world_size,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=config['builder_config']['remove_input_padding'],
model_name=MODEL_NAME,
model_name=model_name,
paged_kv_cache=config['builder_config']['paged_kv_cache'],
quant_mode=QuantMode(config['builder_config']['quant_mode']),
dtype=dtype,
@ -129,19 +136,25 @@ def generate(batch_size, beam_width):

with open(serialize_path, 'rb') as f:
engine_buffer = f.read()
decoder = ChatGLM6BHeadModelGenerationSession(
model_config,
engine_buffer,
runtime_mapping,
)
if model_name == 'chatglm-6b':
decoder = ChatGLMGenerationSession(
model_config,
engine_buffer,
runtime_mapping,
)
else:
decoder = GenerationSession(
model_config,
engine_buffer,
runtime_mapping,
)
decoder.setup(input_ids.size(0), input_ids.size(1), args.max_output_len,
args.beam_width)
output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
torch.cuda.synchronize()

data_path = _pl.Path(__file__).parent.parent / "data/chatglm6b"
if not os.path.exists(str(data_path)):
os.mkdir(data_path)
data_path = Path(__file__).parent.parent / "data" / model_name
data_path.mkdir(parents=True, exist_ok=True)
nBS, nBM = input_ids.size(0), args.beam_width
np.save(
str(data_path) + "/inputId-BS%d-BM%d.npy" % (nBS, nBM),
@ -150,15 +163,23 @@ def generate(batch_size, beam_width):

nMaxOutputLength = 0
for single_output in outputId.reshape(nBS * nBM, -1):
nMaxOutputLength = max(nMaxOutputLength,
np.min(np.where(single_output == end_id)))
if end_id in single_output:
nMaxOutputLength = max(nMaxOutputLength,
np.min(np.where(single_output == end_id)))
else:
nMaxOutputLength = len(single_output)
np.save(
str(data_path) + "/outputId-BS%d-BM%d.npy" % (nBS, nBM),
outputId[:, :, :(nMaxOutputLength + 1)])


if __name__ == '__main__':
generate(batch_size=1, beam_width=1)
generate(batch_size=2, beam_width=1)
generate(batch_size=1, beam_width=2)
print("Finish!")
generate("chatglm-6b", batch_size=1, beam_width=1)
generate("chatglm-6b", batch_size=2, beam_width=1)
generate("chatglm2-6b", batch_size=1, beam_width=1)
generate("chatglm2-6b", batch_size=2, beam_width=1)
generate("chatglm2-6b", batch_size=1, beam_width=2)
generate("chatglm3-6b", batch_size=1, beam_width=1)
generate("chatglm3-6b", batch_size=2, beam_width=1)
generate("chatglm3-6b", batch_size=1, beam_width=2)
print("Done.")
@ -88,8 +88,7 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
model_cache: _tp.Optional[str] = None,
skip_gptj=False,
skip_llama=False,
skip_chatglm6b=False,
skip_chatglm2_6b=False,
skip_chatglm=False,
only_fp8=False,
only_multi_gpu=False,
trt_root: _tp.Optional[str] = None) -> None:
@ -117,15 +116,13 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
model_cache=model_cache,
skip_gptj=skip_gptj,
skip_llama=skip_llama,
skip_chatglm6b=skip_chatglm6b,
skip_chatglm2_6b=skip_chatglm2_6b,
skip_chatglm=skip_chatglm,
only_fp8=only_fp8)

run_google_tests(build_dir=build_dir,
skip_gptj=skip_gptj,
skip_llama=skip_llama,
skip_chatglm6b=skip_chatglm6b,
skip_chatglm2_6b=skip_chatglm2_6b,
skip_chatglm=skip_chatglm,
only_fp8=only_fp8)

run_benchmarks(python_exe=python_exe,
@ -147,8 +144,7 @@ def prepare_all_model_tests(python_exe: str,
model_cache: _tp.Optional[str] = None,
skip_gptj=False,
skip_llama=False,
skip_chatglm6b=False,
skip_chatglm2_6b=False,
skip_chatglm=False,
only_fp8=False):
model_cache_arg = ["--model_cache", model_cache] if model_cache else []
only_fp8_arg = ["--only_fp8"] if only_fp8 else []
@ -178,21 +174,13 @@ def prepare_all_model_tests(python_exe: str,
else:
_log.info("Skipping Lllama tests")

if not skip_chatglm6b:
prepare_model_tests(model_name="chatglm6b",
if not skip_chatglm:
prepare_model_tests(model_name="chatglm",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir)
else:
_log.info("Skipping ChatGLM6B tests")

if not skip_chatglm2_6b:
prepare_model_tests(model_name="chatglm2-6b",
python_exe=python_exe,
root_dir=root_dir,
resources_dir=resources_dir)
else:
_log.info("Skipping ChatGLM2-6B tests")
_log.info("Skipping ChatGLM tests")


def prepare_multi_gpu_model_tests(python_exe: str,
@ -231,13 +219,17 @@ def prepare_model_tests(model_name: str,
str(scripts_dir / f"generate_expected_{model_name}_output.py")
] + only_fp8_arg + only_multi_gpu_arg
if only_multi_gpu_arg:
generate_expected_output = ["mpirun", "-n", "4"
] + generate_expected_output
generate_expected_output = [
"mpirun",
"-n",
"4",
"--allow-run-as-root",
] + generate_expected_output
run_command(generate_expected_output, cwd=root_dir, env=model_env)


def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm6b,
skip_chatglm2_6b, only_fp8):
def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm,
only_fp8):
make_google_tests = [
"cmake", "--build", ".", "--config", "Release", "-j", "--target",
"google-tests"
@ -245,16 +237,14 @@ def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm6b,
run_command(make_google_tests, cwd=build_dir)

cpp_env = {**_os.environ}
ctest = ["ctest", "--output-on-failure", "--output-junit", "report.xml"]
ctest = ["ctest", "--output-on-failure", "--output-junit", "results.xml"]
excluded_tests = []
if skip_gptj:
excluded_tests.append(".*Gptj.*")
if skip_llama:
excluded_tests.append(".*Llama.*")
if skip_chatglm6b:
excluded_tests.append(".*Glm6.*")
if skip_chatglm2_6b:
excluded_tests.append(".*Glm2_6.*")
if skip_chatglm:
excluded_tests.append(".*ChatGlm.*")
if only_fp8:
ctest.extend(["-R", ".*FP8.*"])
else:
@ -274,7 +264,8 @@ def run_multi_gpu_tests(build_dir: _pl.Path):
tests_dir = build_dir / "tests"
cpp_env = {**_os.environ}
session_test = [
"mpirun", "-n", "4", "gptSessionTest", "--gtest_filter=*TP*:*PP*"
"mpirun", "-n", "4", "--allow-run-as-root", "gptSessionTest",
"--gtest_filter=*TP*:*PP*"
]
run_command(session_test, cwd=tests_dir, env=cpp_env)

@ -358,12 +349,9 @@ if __name__ == "__main__":
parser.add_argument("--skip_llama",
action="store_true",
help="Skip the tests for Llama")
parser.add_argument("--skip_chatglm6b",
parser.add_argument("--skip_chatglm",
action="store_true",
help="Skip the tests for ChatGLM6B")
parser.add_argument("--skip_chatglm2_6b",
action="store_true",
help="Skip the tests for ChatGLM2-6B")
help="Skip the tests for ChatGLM")
parser.add_argument(
"--only_fp8",
action="store_true",

@ -148,6 +148,12 @@ public:
int mTPSize;
bool mRandomEndId;
};

struct MicroBatchSizes
{
std::optional<SizeType> ctxMicroBatchSize{std::nullopt};
std::optional<SizeType> genMicroBatchSize{std::nullopt};
};
} // namespace

class SessionTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init)
@ -183,7 +189,7 @@ void verifyModelConfig(GptModelConfig const& modelConfig, ModelSpec const& model

void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds const modelIds, SizeType beamWidth,
std::initializer_list<int> const& batchSizes, fs::path const& resultsFile,
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, SizeType numMicroBatches)
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, MicroBatchSizes microBatchSizes)
{
auto manager = BufferManager(std::make_shared<CudaStream>());

@ -275,7 +281,8 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model
auto const maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
GptSession::Config sessionConfig{maxBatchSize, beamWidth, maxSeqLength};
sessionConfig.decoderPerRequest = modelSpec.mDecoderPerRequest;
sessionConfig.numMicroBatches = numMicroBatches;
sessionConfig.ctxMicroBatchSize = microBatchSizes.ctxMicroBatchSize;
sessionConfig.genMicroBatchSize = microBatchSizes.genMicroBatchSize;
sessionConfig.cudaGraphMode = cudaGraphMode;

GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
@ -327,6 +334,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model

GenerationInput generationInput{
endId, padId, std::move(inputIds), std::move(inputLenghts), modelConfig.usePackedInput()};
generationInput.maxNewTokens = maxNewTokens;

// runtime will allocate memory for output if this tensor is empty
GenerationOutput generationOutput{bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
@ -338,11 +346,19 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model
{
SizeType numSteps = 0;
generationOutput.onTokenGenerated
= [&numSteps, &modelSpec, maxNewTokens]([[maybe_unused]] GenerationOutput::TensorPtr const& outputIds,
[[maybe_unused]] SizeType step, bool finished)
= [&numSteps, &modelSpec, maxNewTokens](
[[maybe_unused]] GenerationOutput::TensorPtr const& outputIds, SizeType step, bool finished)
{
// check that we execute the callback in each step
EXPECT_EQ(step, numSteps);
++numSteps;
EXPECT_TRUE(!finished || modelSpec.mRandomEndId || numSteps == maxNewTokens);
if (!modelSpec.mRandomEndId)
{
// check that we only finish after producing `maxNewTokens` tokens
EXPECT_TRUE(!finished || numSteps == maxNewTokens);
}
// check that `finished` is set to true after producing `maxNewTokens` tokens
EXPECT_TRUE(numSteps != maxNewTokens || finished);
};

session.generate(generationOutput, generationInput, samplingConfig);
@ -416,7 +432,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model

auto constexpr kBatchSizes = {1, 8};

using ParamType = std::tuple<ModelParams, ModelSpec, SizeType, bool, SizeType>;
using ParamType = std::tuple<ModelParams, ModelSpec, SizeType, bool, MicroBatchSizes>;

std::string generateTestName(const testing::TestParamInfo<ParamType>& info)
{
@ -434,9 +450,11 @@ std::string generateTestName(const testing::TestParamInfo<ParamType>& info)
name.append("DecoderBatch");
if (std::get<3>(info.param))
name.append("CudaGraph");
auto const numMicroBatches = std::get<4>(info.param);
if (numMicroBatches > 1)
name.append("MicroBatch" + std::to_string(numMicroBatches));
auto const microBatcheSizes = std::get<4>(info.param);
if (microBatcheSizes.ctxMicroBatchSize)
name.append("CBS" + std::to_string(microBatcheSizes.ctxMicroBatchSize.value()));
if (microBatcheSizes.genMicroBatchSize)
name.append("GBS" + std::to_string(microBatcheSizes.genMicroBatchSize.value()));
if (modelSpec.mPPSize > 1)
name.append("PP" + std::to_string(modelSpec.mPPSize));
if (modelSpec.mTPSize > 1)
@ -458,10 +476,8 @@ TEST_P(ParamTest, Test)
auto const modelIds = modelParams.ids;
auto const modelSpec = std::get<1>(GetParam());
SizeType const beamWidth{std::get<2>(GetParam())};
auto const resultsPath
= DATA_PATH / modelDir / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
fs::path const resultsFile{resultsPath / modelSpec.mResultsFile};
auto const numMicroBatches = std::get<4>(GetParam());
auto const cudaGraphMode = std::get<3>(GetParam());
auto const microBatchSizes = std::get<4>(GetParam());

if (!modelSpec.mUseGptAttentionPlugin && beamWidth > 1)
GTEST_SKIP();
@ -485,10 +501,12 @@ TEST_P(ParamTest, Test)
std::ostringstream gpuSizePath;
gpuSizePath << "tp" << modelSpec.mTPSize << "-pp" << modelSpec.mPPSize << "-gpu";
auto const modelPath{ENGINGE_PATH / modelDir / modelSpec.mModelPath / gpuSizePath.str()};
auto const cudaGraphMode = std::get<3>(GetParam());
auto const resultsPath
= DATA_PATH / modelDir / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
fs::path const resultsFile{resultsPath / modelSpec.mResultsFile};

testGptSession(
modelPath, modelSpec, modelIds, beamWidth, kBatchSizes, resultsFile, mLogger, cudaGraphMode, numMicroBatches);
modelPath, modelSpec, modelIds, beamWidth, kBatchSizes, resultsFile, mLogger, cudaGraphMode, microBatchSizes);
}

INSTANTIATE_TEST_SUITE_P(GptSessionTest, ParamTest,
@ -535,7 +553,8 @@ INSTANTIATE_TEST_SUITE_P(GptSessionTest, ParamTest,
.usePagedKvCache()
.useDecoderPerRequest()
.useRandomEndId()),
testing::Values(1, 2), testing::Values(false, true), testing::Values(1, 3)),
testing::Values(1, 2), testing::Values(false, true),
testing::Values(MicroBatchSizes(), MicroBatchSizes{3, 3}, MicroBatchSizes{3, 6})),
generateTestName);

INSTANTIATE_TEST_SUITE_P(GptjSessionTest, ParamTest,
@ -568,7 +587,7 @@ INSTANTIATE_TEST_SUITE_P(GptjSessionTest, ParamTest,
.useDecoderPerRequest()

),
testing::Values(1, 2), testing::Values(false), testing::Values(1)),
testing::Values(1, 2), testing::Values(false), testing::Values(MicroBatchSizes())),
generateTestName);

INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest,
@ -611,7 +630,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest,
.useTensorParallelism(2)

),
testing::Values(1, 2), testing::Values(false), testing::Values(1)),
testing::Values(1, 2), testing::Values(false), testing::Values(MicroBatchSizes())),
generateTestName);

class LlamaSessionOnDemandTest : public SessionTest
@ -632,7 +651,8 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16WithAttentionPlugin)
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{2, 2};

testGptSession(modelPath, modelSpec, modeIds, beamWidth, batchSizes, resultsFile, mLogger, false, 1);
testGptSession(
modelPath, modelSpec, modeIds, beamWidth, batchSizes, resultsFile, mLogger, false, MicroBatchSizes());
}

TEST_F(LlamaSessionOnDemandTest, SamplingFP16AttentionPluginDecoderBatch)
@ -648,28 +668,34 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16AttentionPluginDecoderBatch)
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin().usePackedInput().useDecoderPerRequest();
auto const modeIds = ModelIds{2, 2};

testGptSession(modelPath, modelSpec, modeIds, beamWidth, batchSizes, resultsFile, mLogger, false, 1);
testGptSession(
modelPath, modelSpec, modeIds, beamWidth, batchSizes, resultsFile, mLogger, false, MicroBatchSizes());
}

class Glm6bSessionTest : public SessionTest
class ChatGlmSessionTest : public SessionTest // for ChatGLM-6B
{
};

class Glm2_6bSessionTest : public SessionTest
class ChatGlm2SessionTest : public SessionTest // for ChatGLM2-6B and ChatGLM2-6B-32k
{
};

// Engines need to be generated using cpp/tests/resources/scripts/build_gpt_engines.py.
// Expected outputs need to be generated using cpp/tests/resources/scripts/generate_expected_gpt_output.py.
class ChatGlm3SessionTest : public SessionTest // for ChatGLM3-6B and ChatGLM3-6B-32k
{
};

// Engines need to be generated using cpp/tests/resources/scripts/build_chatglm_engines.py.
// Expected outputs need to be generated using cpp/tests/resources/scripts/generate_expected_chatglm_output.py.

namespace
{

// TODO: consolidate this function with testGptSession
// Notice: both ChatGLM-6B and ChatGLM2-6B use this function, which are different at GptModelConfig::ModelVariant
void testGlm6bSession(fs::path const& modelPath, std::string const& modelName, ModelSpec const& modelSpec,
// Notice: all ChatGLM models (ChatGLM-6B, ChatGLM2-6B, ChatGLM3-6B, ChatGLM2-6B-32k and ChatGLM3-6B-32k) use this
// function The differences are GptModelConfig::ModelVariant
void testChatGlmSession(fs::path const& modelPath, std::string const& modelName, ModelSpec const& modelSpec,
ModelIds const modelIds, SizeType beamWidth, std::initializer_list<int> const& batchSizes,
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, SizeType numMicroBatches)
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, MicroBatchSizes microBatchSizes)
{
auto manager = BufferManager(std::make_shared<CudaStream>());

@ -692,7 +718,7 @@ void testGlm6bSession(fs::path const& modelPath, std::string const& modelName, M
auto const expectedOutputData = bufferCast<TokenIdType const>(*expectedOutput);

ASSERT_TRUE(fs::exists(modelPath));
auto const json = GptJsonConfig::parse(modelPath / "config.json");
auto const json = GptJsonConfig::parse(modelPath / (modelName + "-config.json"));
auto const modelConfig = json.getModelConfig();
verifyModelConfig(modelConfig, modelSpec);
auto const decoderPerRequest = modelSpec.mDecoderPerRequest;
@ -728,9 +754,9 @@ void testGlm6bSession(fs::path const& modelPath, std::string const& modelName, M
auto const maxBatchSize = *std::max_element(batchSizes.begin(), batchSizes.end());
GptSession::Config sessionConfig{maxBatchSize, beamWidth, maxSeqLength};
sessionConfig.decoderPerRequest = decoderPerRequest;
sessionConfig.numMicroBatches = numMicroBatches;
sessionConfig.ctxMicroBatchSize = microBatchSizes.ctxMicroBatchSize;
sessionConfig.genMicroBatchSize = microBatchSizes.genMicroBatchSize;
sessionConfig.cudaGraphMode = cudaGraphMode;

GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
EXPECT_EQ(session.getDevice(), worldConfig.getDevice());
// Use bufferManager for copying data to and from the GPU
@ -837,62 +863,74 @@ void testGlm6bSession(fs::path const& modelPath, std::string const& modelName, M

} // namespace

TEST_F(Glm6bSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm6b"};
auto const modelPath{ENGINGE_PATH / modelName};
auto const modelName{"chatglm-6b"};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{130005, 130005};

testGlm6bSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, 1);
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
}

TEST_F(Glm6bSessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
{
auto const modelName{"chatglm6b"};
auto const modelPath{ENGINGE_PATH / modelName};
auto const modelName{"chatglm-6b"};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {2};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{130005, 130005};

testGlm6bSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, 1);
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
}

TEST_F(Glm2_6bSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm2-6b"};
auto const modelPath{ENGINGE_PATH / modelName};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{2, 2};

testGlm6bSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, 1);
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
}

TEST_F(Glm2_6bSessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
{
auto const modelName{"chatglm2-6b"};
auto const modelPath{ENGINGE_PATH / modelName};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {2};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{2, 2};

testGlm6bSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, 1);
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
}

TEST_F(Glm2_6bSessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
{
auto const modelName{"chatglm2-6b"};
auto const modelPath{ENGINGE_PATH / modelName};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{2, 2};

testGlm6bSession(modelPath, modelName, modelSpec, modeIds, 2, batchSizes, mLogger, false, 1);
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 2, batchSizes, mLogger, false, MicroBatchSizes());
}

TEST_F(ChatGlm3SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm3-6b"};
auto const modelPath{ENGINGE_PATH / "chatglm"};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
auto const modeIds = ModelIds{2, 2};

testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
}

146
cpp/tests/runtime/iTensorTest.cpp
Normal file
@ -0,0 +1,146 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

using namespace tensorrt_llm::runtime;
using namespace ::testing;

namespace
{

TEST(iTensorTest, UnsqueezeShape)
{
auto oldShape = ITensor::makeShape({2, 3, 4, 5});
{
auto shape = ITensor::unsqueeze(oldShape, 0);

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 1);
EXPECT_EQ(shape.d[1], 2);
EXPECT_EQ(shape.d[2], 3);
EXPECT_EQ(shape.d[3], 4);
EXPECT_EQ(shape.d[4], 5);
}
{
auto shape = ITensor::unsqueeze(oldShape, 1);

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 2);
EXPECT_EQ(shape.d[1], 1);
EXPECT_EQ(shape.d[2], 3);
EXPECT_EQ(shape.d[3], 4);
EXPECT_EQ(shape.d[4], 5);
}

{
auto shape = ITensor::unsqueeze(oldShape, 4);

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 2);
EXPECT_EQ(shape.d[1], 3);
EXPECT_EQ(shape.d[2], 4);
EXPECT_EQ(shape.d[3], 5);
EXPECT_EQ(shape.d[4], 1);
}

std::vector<int> invalidDims{-1, 5, 10};
for (auto invalidDim : invalidDims)
{
try
{
auto shape = ITensor::unsqueeze(oldShape, invalidDim);
FAIL() << "Expected failure";
}
catch (tensorrt_llm::common::TllmException const& e)
{
EXPECT_THAT(e.what(), testing::HasSubstr("Invalid dim"));
}
catch (...)
{
FAIL() << "Expected TllmException";
}
}
}

TEST(iTensorTest, UnsqueezeTensor)
{
auto oldShape = ITensor::makeShape({2, 3, 4, 5});
BufferManager manager(std::make_shared<CudaStream>());

{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(0);
auto shape = tensor->getShape();

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 1);
EXPECT_EQ(shape.d[1], 2);
EXPECT_EQ(shape.d[2], 3);
EXPECT_EQ(shape.d[3], 4);
EXPECT_EQ(shape.d[4], 5);
}
{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(1);
auto shape = tensor->getShape();

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 2);
EXPECT_EQ(shape.d[1], 1);
EXPECT_EQ(shape.d[2], 3);
EXPECT_EQ(shape.d[3], 4);
EXPECT_EQ(shape.d[4], 5);
}

{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(4);
auto shape = tensor->getShape();

EXPECT_EQ(shape.nbDims, 5);
EXPECT_EQ(shape.d[0], 2);
EXPECT_EQ(shape.d[1], 3);
EXPECT_EQ(shape.d[2], 4);
EXPECT_EQ(shape.d[3], 5);
EXPECT_EQ(shape.d[4], 1);
}

std::vector<int> invalidDims{-1, 5, 10};
for (auto invalidDim : invalidDims)
{
try
{
auto tensor = manager.cpu(oldShape, nvinfer1::DataType::kINT32);
tensor->unsqueeze(invalidDim);
FAIL() << "Expected failure";
}
catch (tensorrt_llm::common::TllmException const& e)
{
EXPECT_THAT(e.what(), testing::HasSubstr("Invalid dim"));
}
catch (...)
{
FAIL() << "Expected TllmException";
}
}
}

} // namespace
@ -94,17 +94,29 @@ The statistics are packaged as a JSON string. That string contains the following
* `Active Request Count`, the number of active requests in the batch manager
* `Max Request Count`, the maximum number of requests the batch manager can support at a time

When using in-flight batching, the following additional statistics are reported:
When using the paged KV cache, the following statistics are reported:
* `Max KV cache blocks`, the maximum number of KV cache blocks per GPU
* `Free KV cache blocks`, the number of free KV cache blocks per GPU
* `Used KV cache blocks`, the number of used KV cache blocks per GPU
* `Tokens per KV cache block`, the number of tokens per KV cache block
* `Scheduled Requests`, the number of requests scheduled this iteration

When using in-flight batching, the following additional statistics are reported per step/iteration:

* `Scheduled Requests`, total number of requests scheduled
* `Context Requests`, number of requests in the Context phase
* `Total Context Tokens`, total number of tokens across requests in the context phase
* `Generation Requests`, number of requests in Context phase
* `Generation Requests`, number of requests in the Generation phase
* `MicroBatch ID`, number of requests in Generation phase
* `Total Context Tokens`, total number of tokens across requests in the context phase
* `MicroBatch ID`, the micro batch ID

When using V1 batching, the following additional statistics are reported per V1 iteration:

* `Scheduled Requests`, total number of requests scheduled
* `Context Requests`, number of requests in the Context phase
* `Total Generation Tokens`, total number of tokens generated
* `Total Context Tokens`, total number of tokens across requests in the context phase
* `Empty Generation Slots`, total number of padded slots during the generation phase

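As a rough illustration (not part of this commit), such a payload can be consumed with any JSON parser. The minimal Python sketch below assumes the field names listed above; the exact keys present and the way the string is delivered depend on the batch manager configuration and version.

```python
# Minimal sketch: reading one statistics payload, assuming it arrives as a JSON
# string whose keys match the field names documented above (illustrative only).
import json


def log_stats(stats_json: str) -> None:
    stats = json.loads(stats_json)
    # Fields reported for every iteration.
    active = stats.get("Active Request Count")
    limit = stats.get("Max Request Count")
    # Present only with the paged KV cache / in-flight batching enabled.
    free_blocks = stats.get("Free KV cache blocks")
    scheduled = stats.get("Scheduled Requests")
    print(f"requests: {active}/{limit}, "
          f"free KV cache blocks: {free_blocks}, "
          f"scheduled this iteration: {scheduled}")


# Hypothetical payload with illustrative values.
log_stats('{"Active Request Count": 3, "Max Request Count": 8, '
          '"Free KV cache blocks": 120, "Scheduled Requests": 3}')
```
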
### GptManager Design

@ -266,7 +266,7 @@ second one contains `[9, 2]` and the third one is composed of tokens `[6, 2, 4,
1]`. In total, there are 9 tokens. That's the length. The shape of the tensor
is `[2, 9]`. The first row of the tensor must contain the 9 token IDs and the
second row must store the
[exclusive prefix-sum](https://en.wikipedia.org/wiki/Prefix_sum)
[inclusive prefix-sum](https://en.wikipedia.org/wiki/Prefix_sum)
of the word lengths as shown on the following diagram:

```
@ -274,7 +274,7 @@ of the word lengths as shown on the following diagram:
| | | |
V V V V
[ 5, 7, 3, 9, 2, 6, 2, 4, 1]
[ 0, 3, 5, 9, -1, -1, -1, -1, -1]
[ 3, 5, 9, -1, -1, -1, -1, -1, -1]
```

In case all the words are made of a single token, the inner-most dimension of

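For reference (not part of the diff), a small NumPy sketch that packs the three example words above into the `[2, 9]` layout, with the token IDs in the first row and the inclusive prefix-sum of the word lengths, padded with `-1`, in the second row:

```python
# Minimal sketch of the packed word-list layout described above (illustrative only).
import numpy as np

words = [[5, 7, 3], [9, 2], [6, 2, 4, 1]]  # three words, 9 tokens in total

token_ids = [t for w in words for t in w]              # [5, 7, 3, 9, 2, 6, 2, 4, 1]
offsets = np.cumsum([len(w) for w in words]).tolist()  # inclusive prefix-sum: [3, 5, 9]
offsets += [-1] * (len(token_ids) - len(offsets))      # pad to the same length

packed = np.array([token_ids, offsets], dtype=np.int32)
print(packed.shape)  # (2, 9)
print(packed)
```
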
@ -114,23 +114,26 @@ GPT-J and LLaMA. Those examples can be found in

This release of TensorRT-LLM contains the following examples:

| Model | FP32 | FP16 | BF16 | FP8 | W8A8 SQ | W8A16 | W4A16 | W4A16 AWQ | W4A16 GPTQ |
| :-------------------------- | :--: | :--: | :--: | :--: | :-----: | :---: | :---: | :-------: | :--------: |
| Baichuan | Y | Y | Y | . | . | Y | Y | . | . |
| BERT | Y | Y | Y | . | . | . | . | . | . |
| BLOOM | Y | Y | Y | . | Y | Y | Y | . | . |
| ChatGLM | Y | Y | Y | . | . | . | . | . | . |
| ChatGLM-v2 | Y | Y | Y | . | . | . | . | . | . |
| Falcon | Y | Y | Y | . | . | . | . | . | . |
| GPT | Y | Y | Y | Y | Y | Y | Y | . | . |
| GPT-J | Y | Y | Y | Y | Y | Y | Y | Y | . |
| GPT-NeMo | Y | Y | Y | . | . | . | . | . | . |
| GPT-NeoX | Y | Y | Y | . | . | . | . | . | Y |
| LLaMA | Y | Y | Y | . | Y | Y | Y | Y | Y |
| LLaMA-v2 | Y | Y | Y | Y | Y | Y | Y | Y | Y |
| OPT | Y | Y | Y | . | . | . | . | . | . |
| SantaCoder | Y | Y | Y | . | . | . | . | . | . |
| StarCoder | Y | Y | Y | . | . | . | . | . | . |
| Model | FP32 | FP16 | BF16 | FP8 | W8A8 SQ | W8A16 | W4A16 | W4A16 AWQ | W4A16 GPTQ |
| :--------- | :---: | :---: | :---: | :---: | :-----: | :---: | :---: | :-------: | :--------: |
| Baichuan | Y | Y | Y | . | Y | Y | Y | . | . |
| BERT | Y | Y | Y | . | . | . | . | . | . |
| BLOOM | Y | Y | Y | . | Y | Y | Y | . | . |
| ChatGLM | Y | Y | Y | . | . | . | . | . | . |
| ChatGLM-v2 | Y | Y | Y | . | . | . | . | . | . |
| ChatGLM-v3 | Y | Y | Y | . | . | . | . | . | . |
| Falcon | Y | Y | Y | . | . | . | . | . | . |
| GPT | Y | Y | Y | Y | Y | Y | Y | . | . |
| GPT-J | Y | Y | Y | Y | Y | Y | Y | Y | . |
| GPT-NeMo | Y | Y | Y | . | . | . | . | . | . |
| GPT-NeoX | Y | Y | Y | . | . | . | . | . | Y |
| LLaMA | Y | Y | Y | . | Y | Y | Y | Y | Y |
| LLaMA-v2 | Y | Y | Y | Y | Y | Y | Y | Y | Y |
| OPT | Y | Y | Y | . | . | . | . | . | . |
| SantaCoder | Y | Y | Y | . | . | . | . | . | . |
| StarCoder | Y | Y | Y | . | . | . | . | . | . |
| InternLM | Y | Y | Y | . | Y | Y | Y | . | . |


## Technical Detail: The `QuantMode` Flags

Some files were not shown because too many files have changed in this diff.