# tensorrt_llm/auto_parallel/runtime_profiling.py

import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info

from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin
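
# Runtime profiling utilities for the auto-parallel planner: NvtxProfiler
# wraps a region in an NVTX range, LayerProfiler accumulates per-layer times
# reported by TensorRT, and RuntimeProfiler builds a single-layer network,
# measures its execution time, and caches the result in the device mesh's
# profiling database keyed by layer, shapes, dtypes, and sharding strategy.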


class NvtxProfiler(object):

    def __init__(self, nvtx_name, enable=True):
        self.nvtx_name = nvtx_name
        self.enable = enable

    def __enter__(self):
        if self.enable:
            torch.cuda.nvtx.range_push(self.nvtx_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            torch.cuda.nvtx.range_pop()


class LayerProfiler(trt.IProfiler):

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.layer_count = 0
        self.time = 0

    def report_layer_time(self, layer_name, ms):
        logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
        self.time += ms
        self.layer_count += 1


class RuntimeProfiler(object):

    def __init__(self):
        self.timing_cache = None

    def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
        is_plugin = layer.type == trt.LayerType.PLUGIN_V2
        if is_plugin and len(layer_attrs) > 0:
            # Rebuild the plugin with the updated attributes before profiling.
            plugin_info = get_plugin_info(
                get_trt_network(layer),
                layer.name,
            )
            new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
            layer_attrs = {"plugin": new_plugin}
        graph, output_mapping = get_per_layer_graph(layer, shapes, values,
                                                    layer_attrs)
        graph._io_buffer_mapping = io_buffer_mapping
        network = graph.as_trt()
        # A network whose outputs are all shape tensors has no measurable
        # runtime cost.
        if network.num_outputs > 0 and np.all([
                network.get_output(i).is_shape_tensor
                for i in range(network.num_outputs)
        ]):
            return 0.0
        for proxy_output, output in output_mapping.items():
            shapes[proxy_output] = shapes[output]
        if not self.timing_cache:
            self.timing_cache = network.builder.create_builder_config(
            ).create_timing_cache(b"")
        runner = graph.get_runner(
            shapes,
            values,
            timing_cache=self.timing_cache,
        )
        context = runner.session.context
        context.profiler = LayerProfiler()
        # Run twice: the first run serves as a warm-up, the second run gives
        # the steady-state time. The difference is converted from ms to us.
        runner.run()
        profiler_time_first_run = context.profiler.time
        runner.run()
        return (context.profiler.time - profiler_time_first_run) * 1000.0

    def runtime_profile(self, layer, layer_attrs, input_values, strategy,
                        device_mesh):
        logger.debug(f"start to profile layer {layer.name}")
        shapes = {}
        values = {}
        dtypes = {}
        trt_layer = layer.as_trt()
        sharding_sequences = ()
        # Collect the per-device sharded shapes, dtypes and known values of
        # the inputs under the given sharding strategy.
        for i in range(layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                shapes[input.name] = strategy.sharding_specs[
                    f'input{i}'].get_sharded_shape_per_device()
                dtypes[input.name] = input.dtype
                sharding_sequences += (str(
                    strategy.sharding_specs[f"input{i}"].sharding_sequence), )
                if i in input_values:
                    values[input.name] = input_values[i]
                else:
                    value = layer.get_input(i).value
                    if value is not None:
                        values[input.name] = value
            else:
                sharding_sequences += (None, )
        for i in range(layer.num_outputs):
            output = trt_layer.get_output(i)
            if f'output{i}' in strategy.communication_actions:
                shapes[output.name] = strategy.communication_actions[
                    f'output{i}'].sharding_spec.get_sharded_shape_per_device()
            else:
                shapes[output.name] = strategy.sharding_specs[
                    f'output{i}'].get_sharded_shape_per_device()
            dtypes[output.name] = output.dtype
            sharding_sequences += (str(
                strategy.sharding_specs[f"output{i}"].sharding_sequence), )
        data_key = get_cache_key(
            trt_layer,
            shapes,
            values,
            dtypes=dtypes,
            updated_attrs=layer_attrs,
        )
        data_key += (sharding_sequences, )
        # Reuse a cached measurement if this layer/sharding combination has
        # already been profiled on this cluster.
        elapsed_time = device_mesh.prof_database.query(
            device_mesh.cluster_key,
            data_key,
        )
        if elapsed_time:
            logger.debug(
                f'runtime profiling cache hit {data_key}: {elapsed_time} us')
            return elapsed_time
        with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
            elapsed_time = self._profile(
                layer.as_trt(),
                layer_attrs,
                shapes,
                values,
                layer.graph._io_buffer_mapping,
            )
        logger.debug(
            f'runtime profiling cache miss {data_key}: {elapsed_time} us')
        device_mesh.prof_database.update(
            device_mesh.cluster_key,
            data_key,
            (elapsed_time, strategy.alpha_beta_cost),
        )
        return elapsed_time
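

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): shows how
    # LayerProfiler accumulates reported per-layer times and how NvtxProfiler
    # marks a region with an NVTX range visible in Nsight Systems. The
    # hand-fed timings and the matmul workload below are stand-ins for
    # demonstration only, not values produced by TensorRT.
    profiler = LayerProfiler()
    for name, ms in [("layer_0", 0.12), ("layer_1", 0.34)]:
        profiler.report_layer_time(name, ms)
    print(f"layers={profiler.layer_count}, total={profiler.time:.2f} ms")

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        with NvtxProfiler("example_matmul", enable=True):
            y = x @ x
        torch.cuda.synchronize()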