[https://nvbugs/5717993][fix] Add execution_stream across PyExecutor, KVCacheManager, PeftCacheManager to ensure proper CUDA stream synchronization between KV cache transfer operations and model forward kernels. (#10060)

Signed-off-by: SimengLiu-nv <simengl@nvidia.com>
Simeng Liu 2025-12-31 09:22:54 -08:00 committed by GitHub
parent 0d2e2718ce
commit 84d107b2f0
12 changed files with 321 additions and 36 deletions
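The core of this change is a stream handshake around the model forward: forward kernels run on a dedicated execution_stream that is also handed to the KV-cache components, and the default stream is fenced on both sides of the call. A minimal sketch of that pattern, illustrative only and not the TensorRT-LLM API (the helper name and the torch.matmul stand-in for the model are assumptions):

import torch

execution_stream = torch.cuda.Stream()

def run_forward_on_execution_stream(forward_fn, *args):
    # Let the execution stream wait for work already queued on the current
    # stream (e.g. input preparation) before launching forward kernels on it.
    execution_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(execution_stream):
        outputs = forward_fn(*args)
    # Fence the other way: the current stream waits for the forward kernels
    # before downstream work (sampling, KV cache onboard/offload) uses outputs.
    torch.cuda.current_stream().wait_stream(execution_stream)
    return outputs

x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")
y = run_forward_on_execution_stream(torch.matmul, x, w)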

View File

@@ -77,6 +77,7 @@ class KvCacheCreator:
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
execution_stream: Optional[torch.cuda.Stream] = None,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@@ -97,6 +98,7 @@ class KvCacheCreator:
self._profiling_stage_data = profiling_stage_data
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)
self._execution_stream = execution_stream
def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
@@ -474,6 +476,7 @@ class KvCacheCreator:
max_beam_width=self._max_beam_width,
kv_connector_manager=self._kv_connector_manager,
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
)
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
@@ -527,14 +530,20 @@ class KvCacheCreator:
def _create_kv_cache_manager(
model_engine: PyTorchModelEngine, kv_cache_manager_cls,
mapping: Mapping, kv_cache_config: KvCacheConfig, tokens_per_block: int,
max_seq_len: int, max_batch_size: int,
model_engine: PyTorchModelEngine,
kv_cache_manager_cls,
mapping: Mapping,
kv_cache_config: KvCacheConfig,
tokens_per_block: int,
max_seq_len: int,
max_batch_size: int,
spec_config: Optional[SpeculativeConfig],
sparse_attn_config: Optional[SparseAttentionConfig],
max_num_tokens: int, max_beam_width: int,
max_num_tokens: int,
max_beam_width: int,
kv_connector_manager: Optional[KvCacheConnectorManager],
estimating_kv_cache: bool) -> KVCacheManager:
estimating_kv_cache: bool,
execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
"""
Returns:
A KVCacheManager instance for the given model_engine
@@ -580,6 +589,7 @@ def _create_kv_cache_manager(
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
)
elif is_nemotron_hybrid(config):
if max_beam_width > 1:
@@ -623,6 +633,7 @@ def _create_kv_cache_manager(
dtype=kv_cache_dtype,
spec_config=spec_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
)
elif is_qwen3_next(config):
if max_beam_width > 1:
@@ -672,6 +683,7 @@ def _create_kv_cache_manager(
dtype=kv_cache_dtype,
spec_config=spec_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
)
else:
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
@@ -700,6 +712,7 @@ def _create_kv_cache_manager(
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
)
return kv_cache_manager
@@ -727,6 +740,7 @@ def create_py_executor_instance(
scheduler_config: Optional[SchedulerConfig] = None,
cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
virtual_memory_pools: Optional[dict] = None,
execution_stream: Optional[torch.cuda.Stream] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
@@ -813,6 +827,7 @@ def create_py_executor_instance(
lora_config=lora_config,
model_config=model_binding_config,
world_config=world_config,
execution_stream=execution_stream,
)
resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
model_engine.set_lora_model_config(
@@ -875,7 +890,8 @@ def create_py_executor_instance(
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len,
peft_cache_config=peft_cache_config,
virtual_memory_pools=virtual_memory_pools)
virtual_memory_pools=virtual_memory_pools,
execution_stream=execution_stream)
def create_torch_sampler_args(

View File

@@ -197,6 +197,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
dtype: DataType = DataType.HALF,
spec_config: Optional["DecodingBaseConfig"] = None,
is_estimating_kv_cache: bool = False,
execution_stream: Optional[torch.cuda.Stream] = None,
) -> None:
# mamba hybrid cache requires block reuse to be disabled in KV cache config
@@ -234,6 +235,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
spec_config=spec_config,
layer_mask=layer_mask,
is_estimating_kv_cache=is_estimating_kv_cache,
execution_stream=execution_stream,
)
def prepare_resources(self, scheduled_batch: ScheduledRequests):

View File

@@ -136,11 +136,22 @@ class PyExecutor:
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
peft_cache_config: Optional[PeftCacheConfig] = None,
virtual_memory_pools: Optional[dict] = None):
virtual_memory_pools: Optional[dict] = None,
execution_stream: Optional[torch.cuda.Stream] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = dist.rank
# Store the execution stream for model forward operations.
# This stream is used for proper synchronization with KVCacheTransferManager.
# execution_stream can be provided by create_py_executor
# Create a new stream if none provided
self.execution_stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
)
logger.info(
f"[PyExecutor] execution_stream initialized: {self.execution_stream}. "
)
self.peft_cache_config = peft_cache_config
self.iter_counter = 0
@@ -245,10 +256,19 @@ class PyExecutor:
self.inflight_req_ids = ReqIdsSet()
# During warmup, we don't enable the profiler
# Run warmup on the execution_stream for proper synchronization with
# KVCacheTransferManager's onboard/offload operations.
self.is_warmup = True
self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
self.draft_model_engine.warmup(self.resource_manager)
self.execution_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.execution_stream):
self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
self.draft_model_engine.warmup(self.resource_manager)
# Ensure the default stream waits for execution_stream to complete
# before subsequent operations.
torch.cuda.current_stream().wait_stream(self.execution_stream)
self.is_warmup = False
self.is_shutdown = False
@@ -2231,10 +2251,19 @@ class PyExecutor:
a.py_return_context_logits
for a in scheduled_requests.context_requests)
cache_indirection_buffer = self.sampler.get_cache_indirection()
outputs = forward(scheduled_requests, self.resource_manager,
new_tensors_device, gather_context_logits,
cache_indirection_buffer,
num_accepted_tokens_device)
# Run model forward on the execution stream for proper synchronization
# with KVCacheTransferManager's onboard/offload operations.
self.execution_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.execution_stream):
outputs = forward(scheduled_requests, self.resource_manager,
new_tensors_device, gather_context_logits,
cache_indirection_buffer,
num_accepted_tokens_device)
# Ensure the default stream waits for execution_stream to complete
# before downstream operations use the outputs.
torch.cuda.current_stream().wait_stream(self.execution_stream)
self._kv_connector_wait_for_save()

View File

@@ -601,6 +601,13 @@ def create_py_executor(
resources = {}
estimating_kv_cache = False
kv_cache_creator = None
# Create the execution stream for model forward operations
# for proper synchronization with KVCacheTransferManager's onboard/offload operations.
execution_stream = torch.cuda.Stream()
logger.info(
f"[create_py_executor] Created execution_stream: {execution_stream}")
if model_engine.model.model_config.is_generation:
#NOTE: non-generation models do not have kv cache
kv_cache_creator = KvCacheCreator(
@@ -619,6 +626,7 @@ def create_py_executor(
speculative_config=spec_config,
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
execution_stream=execution_stream,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with allocation_scope(
@@ -676,6 +684,7 @@ def create_py_executor(
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
execution_stream=execution_stream,
)
# Originally, peft_cache_config might be mutated inside
# create_py_executor_instance. Restore it here.
@@ -736,6 +745,7 @@ def create_py_executor(
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
virtual_memory_pools=vm_pools,
execution_stream=execution_stream,
)
_adjust_torch_mem_fraction()

View File

@@ -176,6 +176,7 @@ class KVCacheManager(BaseResourceManager):
indexer_k_cache_quant_block_size: int = 128,
indexer_k_cache_index_head_dim: int = 0,
is_estimating_kv_cache: bool = False,
execution_stream: Optional[torch.cuda.Stream] = None,
**kwargs,
) -> None:
self.mapping = mapping
@@ -351,9 +352,13 @@ class KVCacheManager(BaseResourceManager):
# Set up temp_attention_window_inputs
temp_attention_window_inputs = self._set_temp_attention_window_inputs()
# Note that this stream is unused for now. Will be used for copying to host
# when that feature is enabled.
self._stream = torch.cuda.Stream()
# Use the provided execution stream for proper synchronization with KVCacheTransferManager.
# The execution stream is the stream where model forward kernels run, and KVCacheTransferManager
# needs to synchronize with it for onboard/offload operations.
# If no execution stream is provided, create a new one (for backward compatibility).
self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
)
logger.info(f"[KVCacheManager] execution_stream: {self._stream}")
kwargs = {
'num_kv_heads_per_layer': self.num_kv_heads_per_layer,
'size_per_head': head_dim,
@@ -365,7 +370,7 @@ class KVCacheManager(BaseResourceManager):
'temp_attention_window_inputs': temp_attention_window_inputs,
'dtype': dtype,
'sink_token_length': sink_token_length,
'stream': self._stream.cuda_stream,
'stream': self._stream.cuda_stream, # Pass to BufferManager
'max_sequence_length': max_seq_len,
'enable_block_reuse': kv_cache_config.enable_block_reuse,
'onboard_blocks': kv_cache_config.onboard_blocks,
@@ -1442,7 +1447,8 @@ class PeftCacheManager(BaseResourceManager):
peft_cache_config: PeftCacheConfig,
lora_config: LoraConfig,
model_config: ModelConfigCpp,
world_config: WorldConfig | None = None):
world_config: WorldConfig | None = None,
execution_stream: Optional[torch.cuda.Stream] = None):
import tensorrt_llm.bindings as _tb
peft_cache_config = peft_cache_config._to_pybind()
@@ -1467,8 +1473,12 @@ class PeftCacheManager(BaseResourceManager):
world_config = _tb.WorldConfig()
BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager
buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream,
True)
buffer_manager_stream = execution_stream.cuda_stream if execution_stream is not None else torch.cuda.current_stream(
).cuda_stream
buffer_manager = BufferManager(buffer_manager_stream, True)
logger.info(
f"[PeftCacheManager] buffer_manager_stream: {buffer_manager_stream}"
)
self.impl = PeftCacheManagerCpp(config=peft_cache_manager_config,
model_config=model_config,
world_config=world_config,
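The reason both cache managers receive the forward stream: host/device copies for KV-cache onboard/offload (and LoRA weight loads through the BufferManager) must be ordered against the forward kernels that read and write those blocks. A copy enqueued on the same stream as the kernels is ordered automatically; a copy on a different stream needs an explicit wait. An illustrative sketch, not the TensorRT-LLM implementation (tensor shapes and names are assumptions):

import torch

execution_stream = torch.cuda.Stream()
kv_block = torch.zeros(1024, 1024, device="cuda")
host_buffer = torch.empty(kv_block.shape, dtype=kv_block.dtype,
                          device="cpu", pin_memory=True)

with torch.cuda.stream(execution_stream):
    kv_block.add_(1.0)  # stand-in for forward kernels writing the KV block
    # Offload enqueued on the same stream: guaranteed to see the update.
    host_buffer.copy_(kv_block, non_blocking=True)

# A separate transfer stream would instead have to wait on the execution stream:
transfer_stream = torch.cuda.Stream()
transfer_stream.wait_stream(execution_stream)
with torch.cuda.stream(transfer_stream):
    host_buffer.copy_(kv_block, non_blocking=True)
torch.cuda.synchronize()
print(host_buffer[0, 0].item())  # 1.0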

View File

@@ -380,23 +380,43 @@ class LmEvalEvaluator(Evaluator):
@contextmanager
def _patch_lm_eval(self):
if self.dataset_path is None:
yield
return
from pathlib import Path
import lm_eval
self._task_config_post_init = lm_eval.api.task.TaskConfig.__post_init__
import lm_eval.tasks
def _patched(task_config, *args, **kwargs):
task_config.dataset_path = self.dataset_path
self._task_config_post_init(task_config, *args, **kwargs)
# Patch Path.relative_to to handle custom task paths outside lm_eval/tasks
# This is needed with lm_eval>=0.4.9.2, where the new function pretty_print_task (a local
# function inside get_task_dict) calls yaml_path.relative_to(lm_eval_tasks_path), which fails
# when the yaml comes from tensorrt_llm/evaluate/lm_eval_tasks
original_relative_to = Path.relative_to
lm_eval.api.task.TaskConfig.__post_init__ = _patched
def _patched_relative_to(self, other, *args, **kwargs):
try:
return original_relative_to(self, other, *args, **kwargs)
except ValueError:
# Return absolute path if relative_to fails (path not under base)
return self
Path.relative_to = _patched_relative_to
# Optionally patch dataset_path if provided
original_post_init = None
if self.dataset_path is not None:
original_post_init = lm_eval.api.task.TaskConfig.__post_init__
def _patched_post_init(task_config, *args, **kwargs):
task_config.dataset_path = self.dataset_path
original_post_init(task_config, *args, **kwargs)
lm_eval.api.task.TaskConfig.__post_init__ = _patched_post_init
try:
yield
finally:
lm_eval.api.task.TaskConfig.__post_init__ = self._task_config_post_init
Path.relative_to = original_relative_to
if original_post_init is not None:
lm_eval.api.task.TaskConfig.__post_init__ = original_post_init
def generate_samples(self) -> Iterable[tuple]:
raise NotImplementedError()
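The Path.relative_to patch above can be hard to follow inside the diff; below is a standalone sketch of the same monkey-patching technique, with no lm_eval dependency and made-up example paths:

from contextlib import contextmanager
from pathlib import Path

@contextmanager
def tolerant_relative_to():
    # Temporarily make Path.relative_to fall back to the path itself instead
    # of raising ValueError when the path is not under the given base.
    original_relative_to = Path.relative_to

    def _patched(self, other, *args, **kwargs):
        try:
            return original_relative_to(self, other, *args, **kwargs)
        except ValueError:
            return self

    Path.relative_to = _patched
    try:
        yield
    finally:
        Path.relative_to = original_relative_to

with tolerant_relative_to():
    # Would normally raise ValueError; now returns the original path.
    print(Path("/tmp/custom/task.yaml").relative_to(Path("/opt/lm_eval/tasks")))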

View File

@@ -15,7 +15,8 @@ l0_a100:
tests:
- unittest/llmapi/test_llm_pytorch.py
- unittest/llmapi/test_mpi_session.py ISOLATION
- unittest/llmapi/test_memory_profiling.py # profile kvcache for vision encoder
- unittest/llmapi/test_memory_profiling.py::test_profile_kvcache # profile kvcache for vision encoder
- unittest/llmapi/test_memory_profiling.py::test_pyexecutor_and_kvcache_share_execution_stream # test that PyExecutor and KVCacheManager share the same execution_stream
- unittest/trt/model_api/test_model_quantization.py
# executor
- unittest/executor/test_base_worker.py

View File

@@ -76,9 +76,10 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
# Waive known failures in https://nvbugs/5774869
# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]

View File

@@ -298,7 +298,7 @@ full:L40S/accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_plugin SKIP
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] SKIP (https://nvbugs/5596337)
accuracy/test_llm_api.py::TestMixtral8x7BInstruct::test_awq_tp2 SKIP (https://nvbugs/5598847)
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-MoE-instruct] SKIP (https://nvbugs/5465143)
unittest/llmapi/test_memory_profiling.py SKIP (https://nvbugs/5580781)
unittest/llmapi/test_memory_profiling.py::test_profile_kvcache SKIP (https://nvbugs/5580781)
triton_server/test_triton.py::test_llava[llava] SKIP (https://nvbugs/5547414)
full:RTX/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5569696)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] SKIP (https://nvbugs/5596343)

View File

@@ -101,7 +101,8 @@ def _create_request(num_tokens, req_id: int):
return result
def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None):
def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None,
execution_stream: torch.cuda.Stream = None):
tokens_per_block = 1
max_tokens = 258 # At least 1 more than the max seq len
num_layers = 1
@@ -135,6 +136,7 @@ def create_model_engine_and_kvcache(llm_args: TorchLlmArgs = None):
max_batch_size=batch_size,
mapping=mapping,
dtype=tensorrt_llm.bindings.DataType.HALF,
execution_stream=execution_stream,
)
return model_engine, kv_cache_manager
@@ -480,6 +482,41 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
actual_seq_lens = attn_metadata.seq_lens.cpu().tolist()
self.assertEqual(actual_seq_lens, expected_seq_lens)
def test_kv_cache_manager_with_execution_stream(self):
"""Test that KVCacheManager uses the provided execution_stream.
"""
# Create a dedicated execution stream
execution_stream = torch.cuda.Stream()
model_engine, kv_cache_manager = create_model_engine_and_kvcache(
execution_stream=execution_stream)
# Verify the KVCacheManager uses the provided execution stream
self.assertEqual(
kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
"KVCacheManager should use the provided execution_stream")
resource_manager = ResourceManager(
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
prompt_len = 32
requests = [_create_request(prompt_len, 0)]
batch = ScheduledRequests()
batch.context_requests = requests
batch.generation_requests = []
kv_cache_manager.prepare_resources(batch)
with torch.cuda.stream(execution_stream):
model_engine.forward(batch, resource_manager)
# Verify the stream is still the same after forward pass
self.assertEqual(
kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
"KVCacheManager should still use the provided execution_stream after forward"
)
kv_cache_manager.shutdown()
if __name__ == "__main__":
unittest.main()

View File

@@ -764,6 +764,93 @@ class TestResourceManager(unittest.TestCase):
)
kv_cache_manager.free_resources(req3)
def test_kv_cache_manager_with_execution_stream(self):
"""
Test that KVCacheManager uses the provided execution_stream.
"""
# Create a dedicated execution stream
execution_stream = torch.cuda.Stream()
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=0.1,
max_tokens=256,
)
# Create KVCacheManager with the execution stream
kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=2,
num_kv_heads=2,
head_dim=128,
tokens_per_block=64,
max_seq_len=1024,
max_batch_size=1,
mapping=Mapping(),
execution_stream=execution_stream,
)
# Verify the KVCacheManager uses the provided execution stream
# The internal stream should be the same as the execution stream we provided
self.assertEqual(
kv_cache_manager._stream.cuda_stream, execution_stream.cuda_stream,
"KVCacheManager should use the provided execution_stream")
kv_cache_manager.shutdown()
def test_kv_cache_manager_without_execution_stream(self):
"""Test that KVCacheManager creates its own stream when no execution_stream is provided.
This verifies backward compatibility.
"""
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=0.1,
max_tokens=256,
)
# Create KVCacheManager without providing an execution stream
kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=2,
num_kv_heads=2,
head_dim=128,
tokens_per_block=64,
max_seq_len=1024,
max_batch_size=1,
mapping=Mapping(),
)
# Verify the KVCacheManager creates its own stream
self.assertIsNotNone(
kv_cache_manager._stream,
"KVCacheManager should create its own stream when none is provided")
# The stream should not be the default stream (0)
self.assertNotEqual(kv_cache_manager._stream.cuda_stream, 0,
"KVCacheManager should not use the default stream")
kv_cache_manager.shutdown()
def test_peft_cache_manager_with_execution_stream(self):
"""Test that PeftCacheManager uses the provided execution_stream.
"""
peft_cache_config = self.create_peft_cache_config()
execution_stream = torch.cuda.Stream()
# Create PeftCacheManager with execution_stream
peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
execution_stream=execution_stream,
)
# The PeftCacheManager should be created successfully with the provided stream
self.assertTrue(peft_cache_manager.impl.enabled)
if __name__ == "__main__":
unittest.main()

View File

@@ -3,6 +3,7 @@ import torch
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, SchedulerConfig)
from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, KvCacheConfig,
@@ -75,3 +76,74 @@ def test_profile_kvcache():
torch.cuda.empty_cache()
assert vlm_activation_bytes_with_mm_reqs > vlm_activation_bytes_no_mm_reqs, f"Activation bytes should be higher with mm reqs, but got {vlm_activation_bytes_with_mm_reqs} for mm reqs and {vlm_activation_bytes_no_mm_reqs} without mm reqs"
def test_pyexecutor_and_kvcache_share_execution_stream():
"""Test that PyExecutor and KVCacheManager share the same execution_stream.
The execution_stream is created once in create_py_executor and passed to:
- KVCacheManager (via KvCacheCreator -> _create_kv_cache_manager)
- PyExecutor (via create_py_executor_instance)
Both components must use the same stream for proper synchronization.
"""
# Use a simple model for testing
MODEL = "llama-3.2-models/Llama-3.2-1B-Instruct"
MODEL_PATH = get_model_path(MODEL)
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
free_gpu_memory_fraction=0.5)
build_config = BuildConfig(max_beam_width=1, max_num_tokens=4096)
scheduler_config = SchedulerConfig(
capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, )
backend = "pytorch"
llm_args = {
"model": MODEL,
"scheduler_config": scheduler_config,
"tokenizer": None,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"moe_expert_parallel_size": None,
"gpus_per_node": 1,
"trust_remote_code": False,
"max_batch_size": build_config.max_batch_size,
"max_num_tokens": build_config.max_num_tokens,
"max_beam_width": build_config.max_beam_width,
"max_seq_len": build_config.max_seq_len,
"kv_cache_config": kv_cache_config,
"backend": backend,
"num_postprocess_workers": 0,
"postprocess_tokenizer_dir": MODEL,
"reasoning_parser": None,
"fail_fast_on_attention_window_too_large": False,
}
torchllm_args = TorchLlmArgs(**llm_args)
py_executor = create_py_executor(llm_args=torchllm_args,
checkpoint_dir=MODEL_PATH)
# Get the KVCacheManager from the resource manager
kv_cache_manager = py_executor.resource_manager.get_resource_manager(
ResourceManagerType.KV_CACHE_MANAGER)
# Verify both PyExecutor and KVCacheManager have execution_stream
assert py_executor.execution_stream is not None, \
"PyExecutor should have an execution_stream"
assert kv_cache_manager is not None, \
"KVCacheManager should exist in resource_manager"
assert hasattr(kv_cache_manager, '_stream'), \
"KVCacheManager should have _stream attribute"
# Verify they share the same CUDA stream pointer
assert py_executor.execution_stream.cuda_stream == kv_cache_manager._stream.cuda_stream, \
f"PyExecutor.execution_stream ({py_executor.execution_stream.cuda_stream}) " \
f"should have the same cuda_stream pointer as KVCacheManager._stream ({kv_cache_manager._stream.cuda_stream})"
# Verify they are the exact same stream object
assert py_executor.execution_stream is kv_cache_manager._stream, \
"PyExecutor.execution_stream and KVCacheManager._stream should be the exact same stream object"
py_executor.shutdown()
torch.cuda.empty_cache()