TensorRT-LLMs/tests/unittest/llmapi/test_memory_profiling.py

150 lines
6.2 KiB
Python

import pytest
import torch
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, SchedulerConfig)
from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, KvCacheConfig,
TorchLlmArgs)
# isort: off
from .test_llm import get_model_path
# isort: on
# Module-level pytest marker: applies `threadleak(enabled=False)` to every test
# collected from this file.
# NOTE(review): presumably this turns off a thread-leak check supplied by the
# project's pytest plugin — confirm against the plugin that defines the marker.
pytestmark = pytest.mark.threadleak(enabled=False)
def _profile_activation_bytes(llm_args, checkpoint_dir, enable_mm_reqs):
    """Build a PyExecutor and return the activation bytes it reported.

    Args:
        llm_args: Keyword arguments used to construct ``TorchLlmArgs``.
        checkpoint_dir: Checkpoint location passed to ``create_py_executor``.
        enable_mm_reqs: Value stored under ``"enable_mm_reqs"`` in the
            profiling dict handed to ``create_py_executor``.

    Returns:
        The ``"activation_bytes"`` entry that ``create_py_executor`` is
        expected to write into the profiling dict.

    Raises:
        KeyError: If no ``"activation_bytes"`` entry was recorded.
    """
    profiling_data = {"enable_mm_reqs": enable_mm_reqs}
    py_executor = create_py_executor(llm_args=TorchLlmArgs(**llm_args),
                                     checkpoint_dir=checkpoint_dir,
                                     profiling_stage_data=profiling_data)
    try:
        return profiling_data["activation_bytes"]
    finally:
        # Always tear the executor down, even when the lookup above fails, so
        # the next executor (or the next test) does not start with this one
        # still holding GPU memory.
        py_executor.shutdown()
        torch.cuda.empty_cache()


def test_profile_kvcache():
    """Check that profiling reports more activation bytes with mm requests.

    Two executors are built from the same arguments for a VLM checkpoint,
    once with ``enable_mm_reqs=True`` and once with ``enable_mm_reqs=False``.
    The activation bytes recorded in the first run must be strictly greater
    than those recorded in the second.
    """
    VLM_MODEL = "Qwen2.5-VL-7B-Instruct"
    VLM_MODEL_PATH = get_model_path(VLM_MODEL)

    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                    free_gpu_memory_fraction=0.9)
    cuda_graph_config = CudaGraphConfig(max_batch_size=512)
    build_config = BuildConfig(max_beam_width=1, max_num_tokens=16384)
    dynamic_batch_config = DynamicBatchConfig(
        enable_batch_size_tuning=True,
        enable_max_num_tokens_tuning=False,
        dynamic_batch_moving_average_window=128)
    scheduler_config = SchedulerConfig(
        capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
        dynamic_batch_config=dynamic_batch_config,
    )
    backend = "pytorch"
    llm_args = {
        "model": VLM_MODEL,
        "scheduler_config": scheduler_config,
        "tokenizer": None,
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "moe_expert_parallel_size": None,
        "gpus_per_node": 1,
        "trust_remote_code": False,
        "build_config": build_config,
        "max_batch_size": build_config.max_batch_size,
        "max_num_tokens": build_config.max_num_tokens,
        "max_beam_width": build_config.max_beam_width,
        "max_seq_len": build_config.max_seq_len,
        "kv_cache_config": kv_cache_config,
        "backend": backend,
        "num_postprocess_workers": 0,
        "postprocess_tokenizer_dir": VLM_MODEL,
        "reasoning_parser": None,
        "fail_fast_on_attention_window_too_large": False,
        "cuda_graph_config": cuda_graph_config,
    }

    # The runs are sequential: each executor is fully shut down (and the CUDA
    # cache emptied) before the next one is created.
    vlm_activation_bytes_with_mm_reqs = _profile_activation_bytes(
        llm_args, VLM_MODEL_PATH, enable_mm_reqs=True)
    vlm_activation_bytes_no_mm_reqs = _profile_activation_bytes(
        llm_args, VLM_MODEL_PATH, enable_mm_reqs=False)

    assert vlm_activation_bytes_with_mm_reqs > vlm_activation_bytes_no_mm_reqs, (
        f"Activation bytes should be higher with mm reqs, but got "
        f"{vlm_activation_bytes_with_mm_reqs} for mm reqs and "
        f"{vlm_activation_bytes_no_mm_reqs} without mm reqs")
def test_pyexecutor_and_kvcache_share_execution_stream():
    """Test that PyExecutor and KVCacheManager share the same execution_stream.

    The execution_stream is created once in create_py_executor and passed to:
    - KVCacheManager (via KvCacheCreator -> _create_kv_cache_manager)
    - PyExecutor (via create_py_executor_instance)

    Both components must use the same stream for proper synchronization.
    """
    # Use a simple model for testing
    MODEL = "llama-3.2-models/Llama-3.2-1B-Instruct"
    MODEL_PATH = get_model_path(MODEL)

    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                    free_gpu_memory_fraction=0.5)
    build_config = BuildConfig(max_beam_width=1, max_num_tokens=4096)
    scheduler_config = SchedulerConfig(
        capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, )
    backend = "pytorch"
    llm_args = {
        "model": MODEL,
        "scheduler_config": scheduler_config,
        "tokenizer": None,
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "moe_expert_parallel_size": None,
        "gpus_per_node": 1,
        "trust_remote_code": False,
        "max_batch_size": build_config.max_batch_size,
        "max_num_tokens": build_config.max_num_tokens,
        "max_beam_width": build_config.max_beam_width,
        "max_seq_len": build_config.max_seq_len,
        "kv_cache_config": kv_cache_config,
        "backend": backend,
        "num_postprocess_workers": 0,
        "postprocess_tokenizer_dir": MODEL,
        "reasoning_parser": None,
        "fail_fast_on_attention_window_too_large": False,
    }
    torchllm_args = TorchLlmArgs(**llm_args)
    py_executor = create_py_executor(llm_args=torchllm_args,
                                     checkpoint_dir=MODEL_PATH)

    # Run the checks under try/finally so a failing assertion still shuts the
    # executor down instead of leaving it alive for the tests that follow.
    try:
        # Get the KVCacheManager from the resource manager
        kv_cache_manager = py_executor.resource_manager.get_resource_manager(
            ResourceManagerType.KV_CACHE_MANAGER)

        # Verify both PyExecutor and KVCacheManager have execution_stream
        assert py_executor.execution_stream is not None, \
            "PyExecutor should have an execution_stream"
        assert kv_cache_manager is not None, \
            "KVCacheManager should exist in resource_manager"
        assert hasattr(kv_cache_manager, '_stream'), \
            "KVCacheManager should have _stream attribute"

        # Verify they share the same CUDA stream pointer
        assert py_executor.execution_stream.cuda_stream == kv_cache_manager._stream.cuda_stream, \
            f"PyExecutor.execution_stream ({py_executor.execution_stream.cuda_stream}) " \
            f"should have the same cuda_stream pointer as KVCacheManager._stream ({kv_cache_manager._stream.cuda_stream})"

        # Verify they are the exact same stream object
        assert py_executor.execution_stream is kv_cache_manager._stream, \
            "PyExecutor.execution_stream and KVCacheManager._stream should be the exact same stream object"
    finally:
        py_executor.shutdown()
        torch.cuda.empty_cache()