# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import traceback
from functools import partial
from typing import Literal, Optional, Tuple, Union

# isort: off
import torch
import tensorrt as trt
# isort: on

try:
    import psutil
except ImportError:
    psutil = None

try:
    import pynvml
except ImportError:
    pynvml = None

from tensorrt_llm.logger import logger

from ._common import _is_building

if psutil is None:
    logger.warning(
        "The 'psutil' package is not installed. Host memory usage will not "
        "be monitored. Please install it first, e.g., 'pip install psutil'.")

if pynvml is None:
    logger.warning(
        "The 'pynvml' package is not installed. Device memory usage will not "
        "be monitored. Please install it first, e.g., "
        "'pip install nvidia-ml-py>=12'.")


class Timer:
    """Accumulates elapsed wall-clock time per tag."""

    def __init__(self):
        self._start_times = {}
        self._total_elapsed_times = {}

    def start(self, tag):
        self._start_times[tag] = time.time()

    def stop(self, tag) -> float:
        elapsed_time = time.time() - self._start_times[tag]
        if tag not in self._total_elapsed_times:
            self._total_elapsed_times[tag] = 0
        self._total_elapsed_times[tag] += elapsed_time
        return elapsed_time

    def elapsed_time_in_sec(self, tag) -> Optional[float]:
        if tag not in self._total_elapsed_times:
            return None
        return self._total_elapsed_times[tag]

    def reset(self, tag=None) -> None:
        if tag is None:
            self._start_times.clear()
            self._total_elapsed_times.clear()
        else:
            self._start_times.pop(tag, None)
            self._total_elapsed_times.pop(tag, None)

    def summary(self):
        logger.info("Profile Results")
        for tag, elapsed_time in self._total_elapsed_times.items():
            logger.info(f" - {tag.ljust(30, '.')}: {elapsed_time:.6f} (sec)")


_default_timer = Timer()


def start(tag):
    _default_timer.start(tag)


def stop(tag):
    return _default_timer.stop(tag)


def elapsed_time_in_sec(tag):
    return _default_timer.elapsed_time_in_sec(tag)


def reset(tag=None):
    _default_timer.reset(tag=tag)


def summary():
    _default_timer.summary()


MemUnitType = Literal["GiB", "MiB", "KiB"]


class PyNVMLContext:

    def __enter__(self):
        if pynvml is not None:
            pynvml.nvmlInit()

    def __exit__(self, type, value, traceback):
        if pynvml is not None:
            pynvml.nvmlShutdown()


if pynvml is not None:
    with PyNVMLContext():
        _device_get_memory_info_fn = partial(
            pynvml.nvmlDeviceGetMemoryInfo,
            version=pynvml.nvmlMemory_v2,
        )
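
# Usage sketch for the module-level timing helpers (illustrative only; it
# assumes this module is importable as `tensorrt_llm.profiler`, and the tag
# name is an arbitrary example):
#
#   import tensorrt_llm.profiler as profiler
#   profiler.start("build_engine")
#   ...  # work to be measured
#   profiler.stop("build_engine")
#   profiler.summary()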

def host_memory_info(pid: Optional[int] = None) -> Tuple[int, int, int]:
    if psutil is not None:
        process = psutil.Process(pid)
        # USS reports the amount of memory that would be freed if the process
        # was terminated right now.
        # https://psutil.readthedocs.io/en/latest/index.html#psutil.Process.memory_full_info
        vmem = psutil.virtual_memory()
        total_mem = vmem.total
        free_mem = vmem.available
        alloc_mem = process.memory_full_info().uss
        return alloc_mem, free_mem, total_mem
    return 0, 0, 0  # used, free, total


def device_memory_info(
    device: Optional[Union[torch.device, int]] = None
) -> Tuple[int, int, int]:
    if pynvml is not None:
        if device is None:
            device = torch.cuda.current_device()
        index = device.index if isinstance(device, torch.device) else device
        with PyNVMLContext():
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            mem_info = _device_get_memory_info_fn(handle)
        return mem_info.used, mem_info.free, mem_info.total
    return 0, 0, 0  # used, free, total


def bytes_to_target_unit(mem_bytes: int, unit: MemUnitType) -> float:
    units = {"GiB": 1 << 30, "MiB": 1 << 20, "KiB": 1 << 10}
    _rename_map = {"GB": "GiB", "MB": "MiB", "KB": "KiB"}
    if unit not in units:
        unit = _rename_map[unit]
    return float(mem_bytes) / units[unit]


def _format(mem_bytes: int, unit: MemUnitType) -> str:
    mem_usage = bytes_to_target_unit(mem_bytes, unit)
    return f"{mem_usage:.4f} ({unit})"


def _print_mem_message(msg: str, tag: Optional[str] = None):
    if tag:
        msg = f"{tag} - {msg}"
    logger.info(f"[MemUsage] {msg}")


def print_host_memory_usage(tag: Optional[str] = None,
                            unit: MemUnitType = "GiB"):
    if psutil is None:
        return
    alloc_mem, _, _ = host_memory_info()
    msg = f"Allocated Host Memory {_format(alloc_mem, unit)}"
    _print_mem_message(msg, tag)


def print_device_memory_usage(
    tag: Optional[str] = None,
    unit: MemUnitType = "GiB",
    device: Optional[Union[torch.device, int]] = None,
):
    alloc_mem, _, _ = device_memory_info(device)
    msg = f"Allocated Device Memory {_format(alloc_mem, unit)}"
    _print_mem_message(msg, tag)


def print_memory_usage(
    tag: Optional[str] = None,
    unit: MemUnitType = "GiB",
    device: Optional[Union[torch.device, int]] = None,
):
    alloc_host_mem, _, _ = host_memory_info()
    alloc_device_mem, _, _ = device_memory_info(device=device)
    msg = (f"Allocated Memory: Host {_format(alloc_host_mem, unit)} "
           f"Device {_format(alloc_device_mem, unit)}")
    _print_mem_message(msg, tag)
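
# Usage sketch (illustrative only): host_memory_info()/device_memory_info()
# return (used, free, total) in bytes and fall back to (0, 0, 0) when
# psutil/pynvml are unavailable. The tag string and device index below are
# arbitrary examples:
#
#   print_memory_usage(tag="after loading weights", unit="GiB", device=0)
#   used, free, total = device_memory_info(device=0)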

@_is_building
def check_gpt_mem_usage(
    engine,
    kv_dtype,
    use_gpt_attention_plugin,
    paged_kv_cache,
    max_batch_size,
    max_beam_width,
    max_seq_len,
    local_num_kv_heads,
    head_size,
    num_layers,
) -> int:
    """Estimate the worst-case runtime memory footprint of a GPT engine in MiB."""
    runtime = trt.Runtime(logger.trt_logger)

    # 1. TensorRT engine activation memory
    activation_size = 0
    try:
        cuda_engine = runtime.deserialize_cuda_engine(engine)
        assert cuda_engine is not None
        activation_size = cuda_engine.device_memory_size_v2 / 1024 / 1024
        del cuda_engine
    except Exception:
        logger.warning(
            f"Exception when deserializing engine: {traceback.format_exc()}")
        logger.warning("Activation memory size will be regarded as 0.")
    logger.info(f"Activation memory size: {activation_size:.2f} MiB")

    # 2. Weights
    weights_size = bytes_to_target_unit(engine.nbytes, "MiB")
    logger.info(f"Weights memory size: {weights_size:.2f} MiB")

    # 3. Estimated max KV cache size
    kv_cache_size = (max_batch_size * max_beam_width * 2 * local_num_kv_heads *
                     max_seq_len * head_size * num_layers * kv_dtype.itemsize)
    # Without the GPT attention plugin, two sets of KV cache buffers are
    # needed: one for inputs and one for outputs.
    if not use_gpt_attention_plugin:
        kv_cache_size *= 2
    kv_cache_size = bytes_to_target_unit(kv_cache_size, "MiB")
    logger.info(f"Max KV Cache memory size: {kv_cache_size:.2f} MiB")

    # Estimated total amount of memory
    est_memory_size = activation_size + weights_size + kv_cache_size
    logger.info(
        f"Estimated max memory usage at runtime: {est_memory_size:.2f} MiB")
    _, _, total_mem = device_memory_info(torch.cuda.current_device())
    total_mem = bytes_to_target_unit(total_mem, "MiB")
    if est_memory_size > total_mem:
        logger.warning(
            f"Engine is successfully built, but GPU memory ({total_mem:.2f} MiB)"
            " may not be enough when running inference on max shape.")
        if paged_kv_cache:
            logger.warning(
                "Since paged_kv_cache is enabled, the max KV cache memory size "
                "is an estimate for an extreme case; most workloads are "
                "unlikely to hit OOM.")
        else:
            logger.warning(
                "Enabling `--paged_kv_cache` could help reduce GPU memory "
                "usage at runtime.")

    return est_memory_size
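
if __name__ == "__main__":
    # Minimal self-check sketch, not part of the library API. It assumes the
    # optional psutil/pynvml packages and a CUDA-capable device are available,
    # and that the module is run within its package (assumed invocation:
    # `python -m tensorrt_llm.profiler`). Tag names are arbitrary examples.
    start("dummy_workload")
    time.sleep(0.1)  # stand-in for real work being measured
    stop("dummy_workload")
    summary()
    print_memory_usage(tag="after dummy workload")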