* Move TRT-LLM backend repo to TRT-LLM repo
* Address review comments
* debug ci
* Update triton backend
* Fixes after update

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
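# Helper script to launch Triton Inference Server with the TensorRT-LLM backend.
# It builds a tritonserver command line from the arguments below, optionally
# wraps it in `mpirun` with `world_size` ranks (one server process per rank)
# for tensor-parallel models, and starts it as a background subprocess.
#
# Illustrative invocation (script name and repository path are placeholders,
# not taken from this file):
#   python3 launch_triton_server.py --world_size=2 --model_repo=/path/to/triton_model_repo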
import argparse
import os
import subprocess
import sys
from pathlib import Path

from packaging import version


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument(
        '--tritonserver',
        type=str,
        help='path to the tritonserver exe',
        default='/opt/tritonserver/bin/tritonserver',
    )
    parser.add_argument(
        '--grpc_port',
        type=str,
        help='tritonserver grpc port',
        default='8001',
    )
    parser.add_argument(
        '--http_port',
        type=str,
        help='tritonserver http port',
        default='8000',
    )
    parser.add_argument(
        '--metrics_port',
        type=str,
        help='tritonserver metrics port',
        default='8002',
    )
    parser.add_argument(
        '--force',
        '-f',
        action='store_true',
        help='launch tritonserver regardless of other instances running')
    parser.add_argument('--log',
                        action='store_true',
                        help='log triton server stats into log_file')
    parser.add_argument(
        '--log-file',
        type=str,
        help='path to triton log file',
        default='triton_log.txt',
    )
    parser.add_argument(
        '--no-mpi',
        action='store_true',
        help='Launch tritonserver without MPI (single instance mode)',
        default=False,
    )

    path = str(Path(__file__).parent.absolute()) + '/../all_models/gpt'
    parser.add_argument('--model_repo', type=str, default=path)

    parser.add_argument(
        '--tensorrt_llm_model_name',
        type=str,
        help=
        'Name(s) of the tensorrt_llm Triton model in the repo. Use comma to separate if multiple model names',
        default='tensorrt_llm',
    )

    parser.add_argument(
        '--multi-model',
        action='store_true',
        help=
        'Enable support for multiple TRT-LLM models in the Triton model repository'
    )

    parser.add_argument(
        '--disable-spawn-processes',
        action='store_true',
        help='Disable dynamic spawning of child processes when using multi-model'
    )

    parser.add_argument(
        '--multimodal_gpu0_cuda_mem_pool_bytes',
        type=int,
        default=0,
        help=
        'For multimodal usage, model instances need to transfer GPU tensors, which requires enough CUDA pool memory. We currently assume all multimodal encoders are on GPU 0.'
    )

    parser.add_argument(
        '--oversubscribe',
        action='store_true',
        help=
        'Append --oversubscribe to the mpirun command. Mainly for SLURM MPI use cases.'
    )

    return parser.parse_args()


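# With the argument defaults above, Triton serves HTTP on port 8000, gRPC on
# port 8001, and Prometheus metrics on port 8002 (Triton's standard port
# assignments).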
def number_of_gpus():
    # Count the GPUs visible to nvidia-smi (one line per GPU in the
    # `--list-gpus` output); used to emit the per-GPU CUDA memory pool flags
    # in multi-model mode below.
    output = os.popen('nvidia-smi --list-gpus').read()
    return len(output.strip().split('\n'))


def check_triton_version(required_version):
    # Compare the container's Triton release (from NVIDIA_TRITON_SERVER_VERSION)
    # against `required_version`; returns True only if it is strictly newer.
    try:
        # Default to '' so a missing env var falls through to InvalidVersion
        # instead of raising an uncaught TypeError.
        current_version = version.Version(
            os.environ.get('NVIDIA_TRITON_SERVER_VERSION', ''))
        required_version = version.Version(required_version)
        return current_version > required_version
    except version.InvalidVersion:
        print("Invalid version format. Please use major.minor format.")
        return False


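# Example of the version gate above (illustrative values): in a 24.08 Triton
# container NVIDIA_TRITON_SERVER_VERSION is "24.08", so check_triton_version('24.06')
# evaluates Version("24.08") > Version("24.06") and returns True, enabling the
# multi-model pool settings applied below.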
def add_multi_model_config(cmd, args):
    """Add multi-model configuration to command if enabled."""
    if args.multi_model and check_triton_version(
            '24.06') and not args.disable_spawn_processes:
        cmd += [
            '--pinned-memory-pool-byte-size=0', '--enable-peer-access=false'
        ]
        for j in range(number_of_gpus()):
            cmd += [f'--cuda-memory-pool-byte-size={j}:0']
    return cmd


def add_logging_config(cmd, log, log_file, rank=None):
    """Add logging configuration to command if enabled."""
    if log and (rank is None or rank == 0):
        cmd += ['--log-verbose=3', f'--log-file={log_file}']
    return cmd


def add_port_config(cmd, grpc_port, http_port, metrics_port):
    """Add port configuration to command."""
    cmd += [
        f'--grpc-port={grpc_port}',
        f'--http-port={http_port}',
        f'--metrics-port={metrics_port}',
    ]
    return cmd


def get_cmd(world_size, tritonserver, grpc_port, http_port, metrics_port,
            model_repo, log, log_file, tensorrt_llm_model_name, oversubscribe,
            multimodal_gpu0_cuda_mem_pool_bytes, no_mpi):
    if no_mpi:
        assert world_size == 1, "world size must be 1 when using no-mpi"

    use_mpi = not no_mpi
    cmd = []

    if use_mpi:
        cmd = ['mpirun', '--allow-run-as-root']
        if oversubscribe:
            cmd += ['--oversubscribe']

    for i in range(world_size):
        if use_mpi:
            cmd += ['-n', '1']
        cmd += [tritonserver, f'--model-repository={model_repo}']

        # Add port configuration
        cmd = add_port_config(cmd, grpc_port, http_port, metrics_port)

        # Add logging if requested (only for rank 0)
        cmd = add_logging_config(cmd, log, log_file, i)

        # If rank is not 0, skip loading of models other than `tensorrt_llm_model_name`
        if (i != 0):
            cmd += ['--model-control-mode=explicit']
            model_names = tensorrt_llm_model_name.split(',')
            for name in model_names:
                cmd += [f'--load-model={name}']
        elif i == 0 and multimodal_gpu0_cuda_mem_pool_bytes != 0:
            cmd += [
                f'--cuda-memory-pool-byte-size=0:{multimodal_gpu0_cuda_mem_pool_bytes}'
            ]

        # Add multi-model configuration if enabled (`args` is the module-level
        # namespace set in the __main__ block below)
        cmd = add_multi_model_config(cmd, args)

        cmd += [
            '--disable-auto-complete-config',
            f'--backend-config=python,shm-region-prefix-name=prefix{i}_',
        ]
        if use_mpi:
            cmd += [':']
    return cmd


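# With MPI enabled and world_size=2, the generated command has roughly this
# shape (values are placeholders, flags abbreviated), one colon-separated
# application context per rank:
#   mpirun --allow-run-as-root \
#       -n 1 tritonserver --model-repository=<repo> --grpc-port=8001 ... : \
#       -n 1 tritonserver --model-repository=<repo> --model-control-mode=explicit \
#           --load-model=tensorrt_llm ... :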
if __name__ == '__main__':
    args = parse_arguments()
    # `-r R` restricts the match to tritonserver processes in the running (R) state.
    res = subprocess.run(['pgrep', '-r', 'R', 'tritonserver'],
                         capture_output=True,
                         encoding='utf-8')
    if res.stdout:
        pids = res.stdout.replace('\n', ' ').rstrip()
        msg = f'tritonserver process(es) already found with PID(s): {pids}.\n\tUse `kill {pids}` to stop them.'
        if args.force:
            print(msg, file=sys.stderr)
        else:
            raise RuntimeError(msg + ' Or use --force.')
    cmd = get_cmd(int(args.world_size), args.tritonserver, args.grpc_port,
                  args.http_port, args.metrics_port, args.model_repo, args.log,
                  args.log_file, args.tensorrt_llm_model_name,
                  args.oversubscribe, args.multimodal_gpu0_cuda_mem_pool_bytes,
                  args.no_mpi)
    env = os.environ.copy()
    if args.multi_model:
        if not args.disable_spawn_processes:
            assert args.world_size == 1, 'World size must be 1 when using multi-model without disable-spawn-processes. Processes will be spawned automatically to run the multi-GPU models'
        env['TRTLLM_ORCHESTRATOR'] = '1'
        env['TRTLLM_ORCHESTRATOR_SPAWN_PROCESSES'] = '0' if args.disable_spawn_processes else '1'
    subprocess.Popen(cmd, env=env)
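# Illustrative multi-model launch (script name and repository path are
# placeholders): this sets TRTLLM_ORCHESTRATOR=1 in the server environment so
# processes can be spawned automatically to run the multi-GPU models:
#   python3 launch_triton_server.py --multi-model --model_repo=/path/to/triton_model_repo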