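"""Automatic parallelization of TensorRT-LLM networks.

Source: https://github.com/NVIDIA/TensorRT-LLM

The entry point is ``auto_parallel(network, config)``: it searches for a
tensor-/pipeline-parallel partitioning of the traced network under a
per-device memory budget and returns one ``Network`` per resulting shard.

Minimal usage sketch (the import path and field values below are
illustrative assumptions, not verified defaults):

    from tensorrt_llm.auto_parallel import AutoParallelConfig, auto_parallel

    config = AutoParallelConfig(world_size=8, gpus_per_node=8)
    networks = auto_parallel(network, config)
"""
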
import gc
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import tensorrt as trt
import torch
from filelock import FileLock

from tensorrt_llm.functional import DimRange, Tensor
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, net_guard

from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
from .node_graph import NodeGraph
from .parallelization import ParallelConfig, parallelize
from .pipeline_graph import PipelineGraph
from .simplifier import GraphConfig, Simplifier, StageType
from .utils import current_flags


def to_network(graph: PipelineGraph, network: Network):
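    """Wrap a parallelized ``PipelineGraph`` back into a ``Network``.

    Copies dtype, plugin config, unfilled weights and the auto-parallel
    config over from the source objects, and rebuilds each input ``Tensor``
    wrapper together with its optimization profiles.
    """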
    logger.debug("Converting graph to network")
    trt_network = graph.as_trt()
    trt_network.name = network.trt_network.name
    new_network = Network()
    new_network._init(trt_network)
    new_network._dtype = network._dtype
    new_network._plugin_config = network._plugin_config
    new_network._unfilled_weights = graph._unfilled_weights
    new_network._auto_parallel_config = graph._auto_parallel_config
    with net_guard(network):
        for i in range(trt_network.num_inputs):
            input = trt_network.get_input(i)
            tensor = Tensor(is_network_input=False)
            if input.name in network._inputs:
                profiles = network._inputs[input.name].profiles
            elif len(network._inputs) == 0:
                profiles = []
            else:
                shape = input.shape
                num_profiles = len(list(network._inputs.values())[0].profiles)
                profile = DimRange(shape, [None] * len(shape))
                profiles = [profile] * num_profiles
            tensor.profiles = profiles
            tensor.trt_tensor = input
            new_network._inputs[input.name] = tensor
    return new_network


def find_solution(
    node_graph: NodeGraph,
    graph_config: GraphConfig,
    lmesh: LogicalDeviceMesh,
    memory_budget: int,
    flags: list,
    device: int,
    dump_path: str,
) -> ParallelConfig:
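    """Find the cheapest parallelization strategy for one logical mesh.

    Builds the cost graph for ``lmesh`` and solves it once per candidate
    stage type under ``memory_budget``, keeping the lowest-cost
    ``ParallelConfig``.  When ``dump_path`` is given, the winning solution
    is also visualized on disk, guarded by a file lock shared across
    worker threads.
    """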
    torch.cuda.set_device(device)
    with current_flags(*flags):
        cost_graph = node_graph.get_cost_graph(lmesh)
        num_stages = graph_config.num_stages
        if num_stages == 1:
            stage_types = [None]
        elif num_stages == 2:
            stage_types = [StageType.START, StageType.END]
        else:
            stage_types = [StageType.START, StageType.BLOCK, StageType.END]

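        # Try each candidate slowest-stage assignment and keep the cheapest.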
        best_config, best_solution = None, None
        for stage_type in stage_types:
            if stage_type is not None:
                node_graph.set_slowest_stage(stage_type, graph_config)
            solution = node_graph.find_solution(
                cost_graph,
                memory_budget,
            )
            cost = solution.total_cost
            if best_config is None or cost < best_config.cost:
                best_config = ParallelConfig()
                best_config.graph_config = graph_config
                best_config.lmesh = lmesh
                best_config.cost = cost
                best_config.graph_strategy = solution.node_best_strategy
                best_config.stage_type = stage_type
                best_solution = solution
        if dump_path is not None:
            lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
            vlz_name = f"{dump_path}/solution."
            if graph_config.num_micro_batches != 1:
                vlz_name += f"mbs{graph_config.num_micro_batches}."
            if graph_config.num_stages != 1:
                vlz_name += f"stages{graph_config.num_stages}."
            vlz_name += lmesh.cluster_key
            with lock:
                node_graph.visualize_solution(
                    best_solution,
                    vlz_name,
                    ignore_shape_io=True,
                )
        return best_config


def infer_builder_flags(network):
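    """Infer TensorRT builder flags from the dtypes used in the network.

    Scans network inputs and layer outputs; each reduced-precision dtype
    found (FP16/BF16/INT8/FP8) sets the matching ``trt.BuilderFlag`` bit,
    plus ``OBEY_PRECISION_CONSTRAINTS`` for the floating-point variants.
    """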
    fp16_enabled = False
    bf16_enabled = False
    int8_enabled = False
    fp8_enabled = False

    def check_dtype(tensor):
        nonlocal fp16_enabled
        nonlocal bf16_enabled
        nonlocal int8_enabled
        nonlocal fp8_enabled
        if tensor.dtype == trt.DataType.HALF:
            fp16_enabled = True
        elif tensor.dtype == trt.DataType.BF16:
            bf16_enabled = True
        elif tensor.dtype == trt.DataType.INT8:
            int8_enabled = True
        elif tensor.dtype == trt.DataType.FP8:
            fp8_enabled = True

    trt_network = network.trt_network
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        check_dtype(input)
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            check_dtype(output)

    builder_flags = 0
    if fp16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if bf16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.BF16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if int8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.INT8)
    if fp8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP8)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    return builder_flags


def auto_parallel(network: Network, config: AutoParallelConfig):
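    """Partition ``network`` across ``config.world_size`` devices.

    Simplifies the traced network, enumerates micro-batch counts, pipeline
    splits and logical device meshes, solves every candidate in a thread
    pool (see ``find_solution``), optionally caches the winning
    ``ParallelConfig`` on disk, and returns the parallelized ``Network``
    objects.  A single-device configuration returns the input unchanged.
    """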
    debug_mode = config.debug_mode
    memory_budget = config.get_cluster_info(
    ).memory_budget_per_device * 1024 * 1024 * 1024
    enable_pipeline_parallelism = config.enable_pipeline_parallelism
    if config.world_size < config.gpus_per_node:
        num_hosts = 1
        num_devices_per_host = config.world_size
    else:
        assert config.world_size % config.gpus_per_node == 0
        num_hosts = config.world_size // config.gpus_per_node
        num_devices_per_host = config.gpus_per_node
    parallel_config_cache = config.parallel_config_cache
    dump_path = config.dump_path if debug_mode else None
    fill_weights = config.fill_weights

    if num_hosts == 1 and num_devices_per_host == 1:
        return [network]

    if dump_path is not None:
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)

    builder_flags = config.builder_flags or infer_builder_flags(network)
    flags = [builder_flags, network.strongly_typed]
    with current_flags(*flags):
        simplifier = Simplifier(network, config)
        network_hash = simplifier.get_network_hash()

        best_config = None
        if parallel_config_cache is not None and Path(
                parallel_config_cache).exists():
            parallel_config = ParallelConfig.from_file(parallel_config_cache)
            if (ParallelConfig.VERSION == parallel_config.version
                    and network_hash == parallel_config.network_hash
                    and config == parallel_config.auto_parallel_config):
                logger.info(
                    f"use cache of parallel config from {parallel_config_cache}"
                )
                best_config = parallel_config

        if best_config is None:
            num_devices = num_hosts * num_devices_per_host
            phy_ids = [[
                i + j * num_devices_per_host
                for i in range(num_devices_per_host)
            ] for j in range(num_hosts)]
            phy_mesh = PhysicalDeviceMesh(phy_ids, config)
            if enable_pipeline_parallelism:
                num_micro_batches_list = simplifier.list_all_num_micro_batches()
            else:
                num_micro_batches_list = [1]

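            # Enumerate (micro-batch count, pipeline split, logical mesh)
            # candidates as search jobs.  The memory budget is scaled by
            # num_devices / num_devices_per_stage, presumably because each
            # pipeline stage only holds its own slice of the model.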
            jobs = []
            for num_micro_batches in num_micro_batches_list:
                simplifier.infer_shapes(num_micro_batches)
                if enable_pipeline_parallelism:
                    pipeline_configs = phy_mesh.list_all_pipeline_configs()
                else:
                    pipeline_configs = [(1, num_devices)]
                for num_stages, num_devices_per_stage in pipeline_configs:
                    # TODO: add a fallback path that relaxes the requirement
                    # num_micro_batches >= num_stages if no solution satisfies
                    # the memory budget
                    if num_micro_batches < num_stages:
                        continue
                    simplified_graph, graph_config = simplifier.simplify_graph(
                        phy_mesh,
                        num_stages,
                        num_devices_per_stage,
                    )
                    if simplified_graph is None:
                        continue
                    node_graph = NodeGraph(simplified_graph)
                    node_graph.assign_cost_weights(graph_config)
                    lmeshes = graph_config.stage_phy_meshes[
                        0].get_logical_meshes()
                    for lmesh in lmeshes:
                        jobs.append(
                            (node_graph, graph_config, lmesh, memory_budget *
                             (num_devices / num_devices_per_stage)))

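            # Solve all jobs concurrently and keep the lowest-cost config;
            # the physical mesh is released even if the search fails.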
            try:
                with ThreadPoolExecutor() as executor:
                    best_config = sorted(
                        executor.map(
                            lambda x: find_solution(
                                *x,
                                flags,
                                torch.cuda.current_device(),
                                dump_path,
                            ),
                            jobs,
                        ),
                        key=lambda x: x.cost,
                    )[0]
            finally:
                phy_mesh.close()

            if parallel_config_cache is not None:
                best_config.network_hash = network_hash
                best_config.auto_parallel_config = config
                best_config.save(parallel_config_cache)

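        # Materialize the parallelized graphs for the chosen config and
        # convert each one back into a Network.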
        new_graphs = parallelize(simplifier, best_config)

    networks = [to_network(new_graph, network) for new_graph in new_graphs]
    if debug_mode and fill_weights:
        networks[0]._fill_weights()

    gc.collect()
    torch.cuda.empty_cache()

    return networks