Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

commit 84d107b2f0 (parent 0d2e2718ce)

[https://nvbugs/5717993][fix] Add execution_stream across PyExecutor, KVCacheManager, PeftCacheManager to ensure proper CUDA stream synchronization between KV cache transfer operations and model forward kernels. (#10060)

Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
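For orientation before the hunks: the change boils down to a fork/join pattern on a dedicated CUDA stream. The execution stream first waits for whatever is already queued on the default stream, the warmup/forward kernels are launched inside it, and the default stream then waits for the execution stream before downstream work consumes the outputs. A minimal, hedged sketch of that pattern; the names (run_forward, the tensor sizes) are illustrative stand-ins, not TensorRT-LLM APIs:

```python
# Minimal sketch of the fork/join stream pattern this commit applies around
# warmup and model forward in PyExecutor. Requires a CUDA device.
import torch


def run_forward(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for model_engine.forward(...); any GPU work would do.
    return x @ x.t()


def forward_on_execution_stream(x: torch.Tensor,
                                execution_stream: torch.cuda.Stream) -> torch.Tensor:
    # Fork: the execution stream waits for work already queued on the current
    # (default) stream, e.g. input preparation or KV-cache onboarding copies.
    execution_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(execution_stream):
        out = run_forward(x)
    # Join: the default stream waits for the execution stream before anything
    # downstream (sampling, logits post-processing) reads `out`.
    torch.cuda.current_stream().wait_stream(execution_stream)
    return out


if __name__ == "__main__" and torch.cuda.is_available():
    stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    y = forward_on_execution_stream(x, stream)
    torch.cuda.synchronize()
    print(y.shape)
```

Because KVCacheTransferManager issues its onboard/offload copies relative to the same stream, the cache managers and the executor have to agree on which stream "execution" means; the hunks below thread one shared stream through all of them.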
@@ -77,6 +77,7 @@ class KvCacheCreator:
     speculative_config: SpeculativeConfig,
     sparse_attention_config: SparseAttentionConfig,
     profiling_stage_data: Optional[dict],
+    execution_stream: Optional[torch.cuda.Stream] = None,
 ):
     self._model_engine = model_engine
     self._draft_model_engine = draft_model_engine
@@ -97,6 +98,7 @@ class KvCacheCreator:
     self._profiling_stage_data = profiling_stage_data
     self._kv_cache_manager_cls = get_kv_cache_manager_cls(
         model_engine.model.model_config)
+    self._execution_stream = execution_stream

 def _get_kv_size_per_token(self):
     model_config = self._model_engine.model.model_config
@@ -474,6 +476,7 @@ class KvCacheCreator:
     max_beam_width=self._max_beam_width,
     kv_connector_manager=self._kv_connector_manager,
     estimating_kv_cache=estimating_kv_cache,
+    execution_stream=self._execution_stream,
 )

 # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
@@ -527,14 +530,20 @@ class KvCacheCreator:


 def _create_kv_cache_manager(
-        model_engine: PyTorchModelEngine, kv_cache_manager_cls,
-        mapping: Mapping, kv_cache_config: KvCacheConfig, tokens_per_block: int,
-        max_seq_len: int, max_batch_size: int,
+        model_engine: PyTorchModelEngine,
+        kv_cache_manager_cls,
+        mapping: Mapping,
+        kv_cache_config: KvCacheConfig,
+        tokens_per_block: int,
+        max_seq_len: int,
+        max_batch_size: int,
         spec_config: Optional[SpeculativeConfig],
         sparse_attn_config: Optional[SparseAttentionConfig],
-        max_num_tokens: int, max_beam_width: int,
+        max_num_tokens: int,
+        max_beam_width: int,
         kv_connector_manager: Optional[KvCacheConnectorManager],
-        estimating_kv_cache: bool) -> KVCacheManager:
+        estimating_kv_cache: bool,
+        execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
     """
     Returns:
         A KVCacheManager instance for the given model_engine
@@ -580,6 +589,7 @@ def _create_kv_cache_manager(
         if not estimating_kv_cache else None,
         sparse_attn_config=sparse_attn_config,
         is_estimating_kv_cache=estimating_kv_cache,
+        execution_stream=execution_stream,
     )
 elif is_nemotron_hybrid(config):
     if max_beam_width > 1:
@@ -623,6 +633,7 @@ def _create_kv_cache_manager(
         dtype=kv_cache_dtype,
         spec_config=spec_config,
         is_estimating_kv_cache=estimating_kv_cache,
+        execution_stream=execution_stream,
     )
 elif is_qwen3_next(config):
     if max_beam_width > 1:
@@ -672,6 +683,7 @@ def _create_kv_cache_manager(
         dtype=kv_cache_dtype,
         spec_config=spec_config,
         is_estimating_kv_cache=estimating_kv_cache,
+        execution_stream=execution_stream,
     )
 else:
     # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCahceManager
@@ -700,6 +712,7 @@ def _create_kv_cache_manager(
         if not estimating_kv_cache else None,
         sparse_attn_config=sparse_attn_config,
         is_estimating_kv_cache=estimating_kv_cache,
+        execution_stream=execution_stream,
     )
 return kv_cache_manager

@@ -727,6 +740,7 @@ def create_py_executor_instance(
     scheduler_config: Optional[SchedulerConfig] = None,
     cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
     virtual_memory_pools: Optional[dict] = None,
+    execution_stream: Optional[torch.cuda.Stream] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

@@ -813,6 +827,7 @@ def create_py_executor_instance(
     lora_config=lora_config,
     model_config=model_binding_config,
     world_config=world_config,
+    execution_stream=execution_stream,
 )
 resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
 model_engine.set_lora_model_config(
@@ -875,7 +890,8 @@ def create_py_executor_instance(
     kv_connector_manager=kv_connector_manager,
     max_seq_len=max_seq_len,
     peft_cache_config=peft_cache_config,
-    virtual_memory_pools=virtual_memory_pools)
+    virtual_memory_pools=virtual_memory_pools,
+    execution_stream=execution_stream)


 def create_torch_sampler_args(
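The hunks above are mostly plumbing: one stream object is created up front and threaded through KvCacheCreator, _create_kv_cache_manager, and create_py_executor_instance, with a None default so existing callers keep the old behavior. A hedged sketch of that wiring pattern, with hypothetical class names standing in for the real components:

```python
# Sketch of the shared-stream wiring (hypothetical names, not the TRT-LLM classes).
from typing import Optional

import torch


class CacheManagerSketch:

    def __init__(self, execution_stream: Optional[torch.cuda.Stream] = None):
        # Same idiom as KVCacheManager: reuse the shared stream when provided,
        # otherwise fall back to a private stream for backward compatibility.
        self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream()


class ExecutorSketch:

    def __init__(self,
                 cache_manager: CacheManagerSketch,
                 execution_stream: Optional[torch.cuda.Stream] = None):
        self.execution_stream = execution_stream if execution_stream is not None else torch.cuda.Stream()
        self.cache_manager = cache_manager


def create_executor_sketch() -> ExecutorSketch:
    # Mirrors create_py_executor: create the stream once, pass it everywhere.
    execution_stream = torch.cuda.Stream()
    cache_manager = CacheManagerSketch(execution_stream=execution_stream)
    return ExecutorSketch(cache_manager, execution_stream=execution_stream)


if torch.cuda.is_available():
    executor = create_executor_sketch()
    assert executor.execution_stream is executor.cache_manager._stream
```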
@@ -197,6 +197,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
     dtype: DataType = DataType.HALF,
     spec_config: Optional["DecodingBaseConfig"] = None,
     is_estimating_kv_cache: bool = False,
+    execution_stream: Optional[torch.cuda.Stream] = None,
 ) -> None:

 # mamba hybrid cache requires block reuse to be disabled in KV cache config
@@ -234,6 +235,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
     spec_config=spec_config,
     layer_mask=layer_mask,
     is_estimating_kv_cache=is_estimating_kv_cache,
+    execution_stream=execution_stream,
 )

 def prepare_resources(self, scheduled_batch: ScheduledRequests):
@@ -136,11 +136,22 @@ class PyExecutor:
     kv_connector_manager: Optional[KvCacheConnectorManager] = None,
     max_seq_len: Optional[int] = None,
     peft_cache_config: Optional[PeftCacheConfig] = None,
-    virtual_memory_pools: Optional[dict] = None):
+    virtual_memory_pools: Optional[dict] = None,
+    execution_stream: Optional[torch.cuda.Stream] = None):
     super(PyExecutor, self).__init__()
     self.device_id = torch.cuda.current_device()
     self.global_rank = dist.rank

+    # Store the execution stream for model forward operations.
+    # This stream is used for proper synchronization with KVCacheTransferManager.
+    # execution_stream can be provided by create_py_executor
+    # Create a new stream if none provided
+    self.execution_stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
+    )
+    logger.info(
+        f"[PyExecutor] execution_stream initialized: {self.execution_stream}. "
+    )
+
     self.peft_cache_config = peft_cache_config

     self.iter_counter = 0
@@ -245,10 +256,19 @@ class PyExecutor:
     self.inflight_req_ids = ReqIdsSet()

     # During warmup, we don't enable the profiler
+    # Run warmup on the execution_stream for proper synchronization with
+    # KVCacheTransferManager's onboard/offload operations.
     self.is_warmup = True
-    self.model_engine.warmup(self.resource_manager)
-    if self.draft_model_engine is not None:
-        self.draft_model_engine.warmup(self.resource_manager)
+
+    self.execution_stream.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(self.execution_stream):
+        self.model_engine.warmup(self.resource_manager)
+        if self.draft_model_engine is not None:
+            self.draft_model_engine.warmup(self.resource_manager)
+
+    # Ensure the default stream waits for execution_stream to complete
+    # before subsequent operations.
+    torch.cuda.current_stream().wait_stream(self.execution_stream)
     self.is_warmup = False

     self.is_shutdown = False
@@ -2231,10 +2251,19 @@ class PyExecutor:
     a.py_return_context_logits
     for a in scheduled_requests.context_requests)
 cache_indirection_buffer = self.sampler.get_cache_indirection()
-outputs = forward(scheduled_requests, self.resource_manager,
-                  new_tensors_device, gather_context_logits,
-                  cache_indirection_buffer,
-                  num_accepted_tokens_device)
+
+# Run model forward on the execution stream for proper synchronization
+# with KVCacheTransferManager's onboard/offload operations.
+self.execution_stream.wait_stream(torch.cuda.current_stream())
+with torch.cuda.stream(self.execution_stream):
+    outputs = forward(scheduled_requests, self.resource_manager,
+                      new_tensors_device, gather_context_logits,
+                      cache_indirection_buffer,
+                      num_accepted_tokens_device)
+
+# Ensure the default stream waits for execution_stream to complete
+# before downstream operations use the outputs.
+torch.cuda.current_stream().wait_stream(self.execution_stream)

 self._kv_connector_wait_for_save()
@@ -601,6 +601,13 @@ def create_py_executor(
     resources = {}
     estimating_kv_cache = False
     kv_cache_creator = None
+
+    # Create the execution stream for model forward operations
+    # for proper synchronization with KVCacheTransferManager's onboard/offload operations.
+    execution_stream = torch.cuda.Stream()
+    logger.info(
+        f"[create_py_executor] Created execution_stream: {execution_stream}")
+
     if model_engine.model.model_config.is_generation:
         #NOTE: non-generation models do not have kv cache
         kv_cache_creator = KvCacheCreator(
@@ -619,6 +626,7 @@ def create_py_executor(
     speculative_config=spec_config,
     profiling_stage_data=profiling_stage_data,
     sparse_attention_config=sparse_attention_config,
+    execution_stream=execution_stream,
 )
 estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
 with allocation_scope(
@@ -676,6 +684,7 @@ def create_py_executor(
     scheduler_config=scheduler_config,
     cache_transceiver_config=cache_transceiver_config,
     virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
+    execution_stream=execution_stream,
 )
 # Originally, peft_cache_config might be mutated inside
 # create_py_executor_instance. Restore it here.
@@ -736,6 +745,7 @@ def create_py_executor(
     scheduler_config=scheduler_config,
     cache_transceiver_config=cache_transceiver_config,
     virtual_memory_pools=vm_pools,
+    execution_stream=execution_stream,
 )

 _adjust_torch_mem_fraction()
@@ -176,6 +176,7 @@ class KVCacheManager(BaseResourceManager):
     indexer_k_cache_quant_block_size: int = 128,
     indexer_k_cache_index_head_dim: int = 0,
     is_estimating_kv_cache: bool = False,
+    execution_stream: Optional[torch.cuda.Stream] = None,
     **kwargs,
 ) -> None:
     self.mapping = mapping
@@ -351,9 +352,13 @@ class KVCacheManager(BaseResourceManager):
 # Set up temp_attention_window_inputs
 temp_attention_window_inputs = self._set_temp_attention_window_inputs()

-# Note that this stream is unused for now. Will be used for copying to host
-# when that feature is enabled.
-self._stream = torch.cuda.Stream()
+# Use the provided execution stream for proper synchronization with KVCacheTransferManager.
+# The execution stream is the stream where model forward kernels run, and KVCacheTransferManager
+# needs to synchronize with it for onboard/offload operations.
+# If no execution stream is provided, create a new one (for backward compatibility).
+self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
+)
+logger.info(f"[KVCacheManager] execution_stream: {self._stream}")
 kwargs = {
     'num_kv_heads_per_layer': self.num_kv_heads_per_layer,
     'size_per_head': head_dim,
@@ -365,7 +370,7 @@ class KVCacheManager(BaseResourceManager):
     'temp_attention_window_inputs': temp_attention_window_inputs,
     'dtype': dtype,
     'sink_token_length': sink_token_length,
-    'stream': self._stream.cuda_stream,
+    'stream': self._stream.cuda_stream, # Pass to BufferManager
     'max_sequence_length': max_seq_len,
     'enable_block_reuse': kv_cache_config.enable_block_reuse,
     'onboard_blocks': kv_cache_config.onboard_blocks,
@@ -1442,7 +1447,8 @@ class PeftCacheManager(BaseResourceManager):
     peft_cache_config: PeftCacheConfig,
     lora_config: LoraConfig,
     model_config: ModelConfigCpp,
-    world_config: WorldConfig | None = None):
+    world_config: WorldConfig | None = None,
+    execution_stream: Optional[torch.cuda.Stream] = None):
     import tensorrt_llm.bindings as _tb

     peft_cache_config = peft_cache_config._to_pybind()
@@ -1467,8 +1473,12 @@ class PeftCacheManager(BaseResourceManager):
     world_config = _tb.WorldConfig()

 BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager
-buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream,
-                               True)
+buffer_manager_stream = execution_stream.cuda_stream if execution_stream is not None else torch.cuda.current_stream(
+).cuda_stream
+buffer_manager = BufferManager(buffer_manager_stream, True)
+logger.info(
+    f"[PeftCacheManager] buffer_manager_stream: {buffer_manager_stream}"
+)
 self.impl = PeftCacheManagerCpp(config=peft_cache_manager_config,
                                 model_config=model_config,
                                 world_config=world_config,
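One detail worth noting in the resource-manager hunks: the Python-side torch.cuda.Stream crosses into C++ only as its raw handle ('stream': self._stream.cuda_stream for the KV-cache block-manager kwargs, buffer_manager_stream for the PEFT BufferManager). The snippet below is only a demonstration that the integer handle identifies the same underlying CUDA stream; the ExternalStream round-trip is illustrative and not part of this change:

```python
# Hedged illustration of passing a raw CUDA stream handle across an API boundary.
import torch

if torch.cuda.is_available():
    execution_stream = torch.cuda.Stream()

    # Integer handle of the underlying cudaStream_t; this is the value the
    # hunks above hand to the C++ BufferManager / block manager.
    handle = execution_stream.cuda_stream

    # Another component can wrap the raw handle and will enqueue work on the
    # very same CUDA stream.
    same_stream_view = torch.cuda.ExternalStream(handle)
    assert same_stream_view.cuda_stream == execution_stream.cuda_stream
```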
@@ -380,23 +380,43 @@ class LmEvalEvaluator(Evaluator):

 @contextmanager
 def _patch_lm_eval(self):
-    if self.dataset_path is None:
-        yield
-        return
+    from pathlib import Path

     import lm_eval
-    self._task_config_post_init = lm_eval.api.task.TaskConfig.__post_init__
+    import lm_eval.tasks

-    def _patched(task_config, *args, **kwargs):
-        task_config.dataset_path = self.dataset_path
-        self._task_config_post_init(task_config, *args, **kwargs)
+    # Patch Path.relative_to to handle custom task paths outside lm_eval/tasks
+    # This is needed with lm_eval>=0.4.9.2 with new function pretty_print_task (a local function inside
+    # get_task_dict) calls yaml_path.relative_to(lm_eval_tasks_path) which fails
+    # when the yaml is from tensorrt_llm/evaluate/lm_eval_tasks
+    original_relative_to = Path.relative_to

-    lm_eval.api.task.TaskConfig.__post_init__ = _patched
+    def _patched_relative_to(self, other, *args, **kwargs):
+        try:
+            return original_relative_to(self, other, *args, **kwargs)
+        except ValueError:
+            # Return absolute path if relative_to fails (path not under base)
+            return self
+
+    Path.relative_to = _patched_relative_to
+
+    # Optionally patch dataset_path if provided
+    original_post_init = None
+    if self.dataset_path is not None:
+        original_post_init = lm_eval.api.task.TaskConfig.__post_init__
+
+        def _patched_post_init(task_config, *args, **kwargs):
+            task_config.dataset_path = self.dataset_path
+            original_post_init(task_config, *args, **kwargs)
+
+        lm_eval.api.task.TaskConfig.__post_init__ = _patched_post_init

     try:
         yield
     finally:
-        lm_eval.api.task.TaskConfig.__post_init__ = self._task_config_post_init
+        Path.relative_to = original_relative_to
+        if original_post_init is not None:
+            lm_eval.api.task.TaskConfig.__post_init__ = original_post_init

 def generate_samples(self) -> Iterable[tuple]:
     raise NotImplementedError()
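The _patch_lm_eval rewrite above uses the standard save/patch/restore idiom: capture the original attribute, install the patched version, and put the original back in a finally block so the patch cannot leak past the context manager, even on error. A self-contained sketch of the same idiom, reduced to the Path.relative_to patch (the fallback behaviour mirrors the hunk):

```python
# Hedged sketch of the save/patch/restore context-manager idiom used above.
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def tolerant_relative_to():
    original_relative_to = Path.relative_to

    def _patched(self, other, *args, **kwargs):
        try:
            return original_relative_to(self, other, *args, **kwargs)
        except ValueError:
            # Fall back to the unmodified path when it is not under `other`.
            return self

    Path.relative_to = _patched
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        Path.relative_to = original_relative_to


if __name__ == "__main__":
    with tolerant_relative_to():
        # Normally raises ValueError; with the patch it returns the path as-is.
        print(Path("/tmp/example.yaml").relative_to("/opt/tasks"))
```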
@@ -15,7 +15,8 @@ l0_a100:
   tests:
   - unittest/llmapi/test_llm_pytorch.py
   - unittest/llmapi/test_mpi_session.py ISOLATION
-  - unittest/llmapi/test_memory_profiling.py # profile kvcache for vision encoder
+  - unittest/llmapi/test_memory_profiling.py::test_profile_kvcache # profile kvcache for vision encoder
+  - unittest/llmapi/test_memory_profiling.py::test_pyexecutor_and_kvcache_share_execution_stream # test that PyExecutor and KVCacheManager share the same execution_stream
   - unittest/trt/model_api/test_model_quantization.py
   # executor
   - unittest/executor/test_base_worker.py
@@ -76,9 +76,10 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
-- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
-- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
-- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
+# Waive known failures in https://nvbugs/5774869
+# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
+# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
+# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]
@@ -298,7 +298,7 @@ full:L40S/accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_plugin SKIP
 full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] SKIP (https://nvbugs/5596337)
 accuracy/test_llm_api.py::TestMixtral8x7BInstruct::test_awq_tp2 SKIP (https://nvbugs/5598847)
 examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-MoE-instruct] SKIP (https://nvbugs/5465143)
-unittest/llmapi/test_memory_profiling.py SKIP (https://nvbugs/5580781)
+unittest/llmapi/test_memory_profiling.py::test_profile_kvcache SKIP (https://nvbugs/5580781)
 triton_server/test_triton.py::test_llava[llava] SKIP (https://nvbugs/5547414)
 full:RTX/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5569696)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] SKIP (https://nvbugs/5596343)
@@ -101,7 +101,8 @@ def _create_request(num_tokens, req_id: int):
     return result


-def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None):
+def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None,
+                                    execution_stream: torch.cuda.Stream = None):
     tokens_per_block = 1
     max_tokens = 258 # Atleast 1 more than the max seq len
     num_layers = 1
@@ -135,6 +136,7 @@ def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None):
     max_batch_size=batch_size,
     mapping=mapping,
     dtype=tensorrt_llm.bindings.DataType.HALF,
+    execution_stream=execution_stream,
 )

 return model_engine, kv_cache_manager
@@ -480,6 +482,41 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
     actual_seq_lens = attn_metadata.seq_lens.cpu().tolist()
     self.assertEqual(actual_seq_lens, expected_seq_lens)

+def test_kv_cache_manager_with_execution_stream(self):
+    """Test that KVCacheManager uses the provided execution_stream.
+    """
+    # Create a dedicated execution stream
+    execution_stream = torch.cuda.Stream()
+
+    model_engine, kv_cache_manager = create_model_engine_and_kvcache(
+        execution_stream=execution_stream)
+
+    # Verify the KVCacheManager uses the provided execution stream
+    self.assertEqual(
+        kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
+        "KVCacheManager should use the provided execution_stream")
+
+    resource_manager = ResourceManager(
+        {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
+
+    prompt_len = 32
+    requests = [_create_request(prompt_len, 0)]
+
+    batch = ScheduledRequests()
+    batch.context_requests = requests
+    batch.generation_requests = []
+    kv_cache_manager.prepare_resources(batch)
+    with torch.cuda.stream(execution_stream):
+        model_engine.forward(batch, resource_manager)
+
+    # Verify the stream is still the same after forward pass
+    self.assertEqual(
+        kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
+        "KVCacheManager should still use the provided execution_stream after forward"
+    )
+
+    kv_cache_manager.shutdown()


 if __name__ == "__main__":
     unittest.main()
@@ -764,6 +764,93 @@ class TestResourceManager(unittest.TestCase):
     )
     kv_cache_manager.free_resources(req3)

+def test_kv_cache_manager_with_execution_stream(self):
+    """
+    Test that KVCacheManager uses the provided execution_stream.
+    """
+    # Create a dedicated execution stream
+    execution_stream = torch.cuda.Stream()
+
+    kv_cache_config = KvCacheConfig(
+        free_gpu_memory_fraction=0.1,
+        max_tokens=256,
+    )
+
+    # Create KVCacheManager with the execution stream
+    kv_cache_manager = KVCacheManager(
+        kv_cache_config=kv_cache_config,
+        kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
+        CacheType.SELF,
+        num_layers=2,
+        num_kv_heads=2,
+        head_dim=128,
+        tokens_per_block=64,
+        max_seq_len=1024,
+        max_batch_size=1,
+        mapping=Mapping(),
+        execution_stream=execution_stream,
+    )
+
+    # Verify the KVCacheManager uses the provided execution stream
+    # The internal stream should be the same as the execution stream we provided
+    self.assertEqual(
+        kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
+        "KVCacheManager should use the provided execution_stream")
+
+    kv_cache_manager.shutdown()
+
+def test_kv_cache_manager_without_execution_stream(self):
+    """Test that KVCacheManager creates its own stream when no execution_stream is provided.
+
+    This verifies backward compatibility.
+    """
+    kv_cache_config = KvCacheConfig(
+        free_gpu_memory_fraction=0.1,
+        max_tokens=256,
+    )
+
+    # Create KVCacheManager without providing an execution stream
+    kv_cache_manager = KVCacheManager(
+        kv_cache_config=kv_cache_config,
+        kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
+        CacheType.SELF,
+        num_layers=2,
+        num_kv_heads=2,
+        head_dim=128,
+        tokens_per_block=64,
+        max_seq_len=1024,
+        max_batch_size=1,
+        mapping=Mapping(),
+    )
+
+    # Verify the KVCacheManager creates its own stream
+    self.assertIsNotNone(
+        kv_cache_manager._stream,
+        "KVCacheManager should create its own stream when none is provided")
+
+    # The stream should not be the default stream (0)
+    self.assertNotEqual(kv_cache_manager._stream.cuda_stream, 0,
+                        "KVCacheManager should not use the default stream")
+
+    kv_cache_manager.shutdown()
+
+def test_peft_cache_manager_with_execution_stream(self):
+    """Test that PeftCacheManager uses the provided execution_stream.
+    """
+    peft_cache_config = self.create_peft_cache_config()
+    execution_stream = torch.cuda.Stream()
+
+    # Create PeftCacheManager with execution_stream
+    peft_cache_manager = PeftCacheManager(
+        peft_cache_config=peft_cache_config,
+        lora_config=LoraConfig(),
+        model_config=self.model_config,
+        execution_stream=execution_stream,
+    )
+
+    # The PeftCacheManager should be created successfully with the provided stream
+    self.assertTrue(peft_cache_manager.impl.enabled)


 if __name__ == "__main__":
     unittest.main()
@@ -3,6 +3,7 @@ import torch

 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
     create_py_executor
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
                                  DynamicBatchConfig, SchedulerConfig)
 from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, KvCacheConfig,
@@ -75,3 +76,74 @@ def test_profile_kvcache():
     torch.cuda.empty_cache()

     assert vlm_activation_bytes_with_mm_reqs > vlm_activation_bytes_no_mm_reqs, f"Activation bytes should be higher with mm reqs, but got {vlm_activation_bytes_with_mm_reqs} for mm reqs and {vlm_activation_bytes_no_mm_reqs} without mm reqs"
+
+
+def test_pyexecutor_and_kvcache_share_execution_stream():
+    """Test that PyExecutor and KVCacheManager share the same execution_stream.
+
+    The execution_stream is created once in create_py_executor and passed to:
+    - KVCacheManager (via KvCacheCreator -> _create_kv_cache_manager)
+    - PyExecutor (via create_py_executor_instance)
+
+    Both components must use the same stream for proper synchronization.
+    """
+    # Use a simple model for testing
+    MODEL = "llama-3.2-models/Llama-3.2-1B-Instruct"
+    MODEL_PATH = get_model_path(MODEL)
+
+    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                    free_gpu_memory_fraction=0.5)
+
+    build_config = BuildConfig(max_beam_width=1, max_num_tokens=4096)
+    scheduler_config = SchedulerConfig(
+        capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, )
+    backend = "pytorch"
+    llm_args = {
+        "model": MODEL,
+        "scheduler_config": scheduler_config,
+        "tokenizer": None,
+        "tensor_parallel_size": 1,
+        "pipeline_parallel_size": 1,
+        "moe_expert_parallel_size": None,
+        "gpus_per_node": 1,
+        "trust_remote_code": False,
+        "max_batch_size": build_config.max_batch_size,
+        "max_num_tokens": build_config.max_num_tokens,
+        "max_beam_width": build_config.max_beam_width,
+        "max_seq_len": build_config.max_seq_len,
+        "kv_cache_config": kv_cache_config,
+        "backend": backend,
+        "num_postprocess_workers": 0,
+        "postprocess_tokenizer_dir": MODEL,
+        "reasoning_parser": None,
+        "fail_fast_on_attention_window_too_large": False,
+    }
+
+    torchllm_args = TorchLlmArgs(**llm_args)
+
+    py_executor = create_py_executor(llm_args=torchllm_args,
+                                     checkpoint_dir=MODEL_PATH)
+
+    # Get the KVCacheManager from the resource manager
+    kv_cache_manager = py_executor.resource_manager.get_resource_manager(
+        ResourceManagerType.KV_CACHE_MANAGER)
+
+    # Verify both PyExecutor and KVCacheManager have execution_stream
+    assert py_executor.execution_stream is not None, \
+        "PyExecutor should have an execution_stream"
+    assert kv_cache_manager is not None, \
+        "KVCacheManager should exist in resource_manager"
+    assert hasattr(kv_cache_manager, '_stream'), \
+        "KVCacheManager should have _stream attribute"
+
+    # Verify they share the same CUDA stream pointer
+    assert py_executor.execution_stream.cuda_stream == kv_cache_manager._stream.cuda_stream, \
+        f"PyExecutor.execution_stream ({py_executor.execution_stream.cuda_stream}) " \
+        f"should have the same cuda_stream pointer as KVCacheManager._stream ({kv_cache_manager._stream.cuda_stream})"
+
+    # Verify they are the exact same stream object
+    assert py_executor.execution_stream is kv_cache_manager._stream, \
+        "PyExecutor.execution_stream and KVCacheManager._stream should be the exact same stream object"
+
+    py_executor.shutdown()
+    torch.cuda.empty_cache()