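"""Per-layer runtime profiling for TensorRT-LLM's auto-parallel planner.

Measures how long a single layer takes to execute under a given sharding
strategy and memoizes the results in the device mesh's profiling database.
"""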
import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info

from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin


class NvtxProfiler(object):
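    """Context manager that brackets a code region with an NVTX range so the
    profiled layer can be located in an Nsight Systems timeline."""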

    def __init__(self, nvtx_name, enable=True):
        self.nvtx_name = nvtx_name
        self.enable = enable

    def __enter__(self):
        if self.enable:
            torch.cuda.nvtx.range_push(self.nvtx_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            torch.cuda.nvtx.range_pop()


class LayerProfiler(trt.IProfiler):
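    """trt.IProfiler implementation that accumulates the per-layer execution
    times (in milliseconds) reported by the TensorRT execution context."""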

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.layer_count = 0
        self.time = 0

    def report_layer_time(self, layer_name, ms):
        logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
        self.time += ms
        self.layer_count += 1


class RuntimeProfiler(object):
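    """Profiles individual layers by compiling and timing them in isolation,
    caching measurements in the device mesh's profiling database."""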

    def __init__(self):
        self.timing_cache = None

    def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
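        # For plugin layers whose attributes are being varied, rebuild the
        # plugin so the profiled graph reflects the updated attributes.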
        is_plugin = layer.type == trt.LayerType.PLUGIN_V2
        if is_plugin and len(layer_attrs) > 0:
            plugin_info = get_plugin_info(
                get_trt_network(layer),
                layer.name,
            )
            new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
            layer_attrs = {"plugin": new_plugin}
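        # Wrap this single layer in a minimal graph so it can be built and
        # timed in isolation.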
        graph, output_mapping = get_per_layer_graph(layer, shapes, values,
                                                    layer_attrs)
        graph._io_buffer_mapping = io_buffer_mapping
        network = graph.as_trt()
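        # If every output is a shape tensor, the layer is resolved at build
        # time rather than executed on the GPU, so its runtime cost is zero.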
        if network.num_outputs > 0 and np.all([
                network.get_output(i).is_shape_tensor
                for i in range(network.num_outputs)
        ]):
            return 0.0
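        # Proxy outputs introduced by the per-layer graph take the shapes of
        # the outputs they stand in for.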
        for proxy_output, output in output_mapping.items():
            shapes[proxy_output] = shapes[output]
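        # Create the (initially empty) timing cache lazily and reuse it
        # across layers so TensorRT's tactic timing is amortized.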
        if not self.timing_cache:
            self.timing_cache = network.builder.create_builder_config(
            ).create_timing_cache(b"")
        runner = graph.get_runner(
            shapes,
            values,
            timing_cache=self.timing_cache,
        )
        context = runner.session.context
        context.profiler = LayerProfiler()
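        # Run twice: the first run absorbs warm-up costs, and subtracting the
        # cumulative time after run one isolates the second, steady-state run.
        # The factor of 1000 converts milliseconds to microseconds.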
        runner.run()
        profiler_time_first_run = context.profiler.time
        runner.run()
        return (context.profiler.time - profiler_time_first_run) * 1000.0

    def runtime_profile(self, layer, layer_attrs, input_values, strategy,
                        device_mesh):
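        """Return the elapsed time (in microseconds) of `layer` under
        `strategy`, consulting the profiling database before measuring."""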
        logger.debug(f"start to profile layer {layer.name}")
        shapes = {}
        values = {}
        dtypes = {}
        trt_layer = layer.as_trt()

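        # Collect each input's per-device (sharded) shape, dtype, and any
        # known constant value; the sharding sequences later become part of
        # the cache key so distinct strategies are profiled separately.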
        sharding_sequences = ()
        for i in range(layer.num_inputs):
            input = trt_layer.get_input(i)
            if input is not None:
                shapes[input.name] = strategy.sharding_specs[
                    f'input{i}'].get_sharded_shape_per_device()
                dtypes[input.name] = input.dtype
                sharding_sequences += (str(
                    strategy.sharding_specs[f"input{i}"].sharding_sequence), )
                if i in input_values:
                    values[input.name] = input_values[i]
                else:
                    value = layer.get_input(i).value
                    if value is not None:
                        values[input.name] = value
            else:
                sharding_sequences += (None, )

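        # Output shapes come from the communication action's sharding spec
        # when the output has one attached, otherwise from the strategy's
        # own sharding spec.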
        for i in range(layer.num_outputs):
            output = trt_layer.get_output(i)
            if f'output{i}' in strategy.communication_actions:
                shapes[output.name] = strategy.communication_actions[
                    f'output{i}'].sharding_spec.get_sharded_shape_per_device()
            else:
                shapes[output.name] = strategy.sharding_specs[
                    f'output{i}'].get_sharded_shape_per_device()
            dtypes[output.name] = output.dtype
            sharding_sequences += (str(
                strategy.sharding_specs[f"output{i}"].sharding_sequence), )
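        # Build the cache key and consult the profiling database; only a
        # miss triggers an actual measurement.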
        data_key = get_cache_key(
            trt_layer,
            shapes,
            values,
            dtypes=dtypes,
            updated_attrs=layer_attrs,
        )
        data_key += (sharding_sequences, )
        elapsed_time = device_mesh.prof_database.query(
            device_mesh.cluster_key,
            data_key,
        )
        if elapsed_time:
            logger.debug(
                f'runtime profiling cache hit {data_key}: {elapsed_time} us')
            return elapsed_time
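        # Measure under an NVTX range so this layer's profiling run can be
        # spotted in an Nsight Systems trace.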
        with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
            elapsed_time = self._profile(
                layer.as_trt(),
                layer_attrs,
                shapes,
                values,
                layer.graph._io_buffer_mapping,
            )
        logger.debug(
            f'runtime profiling cache miss {data_key}: {elapsed_time} us')

        device_mesh.prof_database.update(
            device_mesh.cluster_key,
            data_key,
            (elapsed_time, strategy.alpha_beta_cost),
        )

        return elapsed_time