fix: improve PyExecutor resource allocations (#4299)

chore: restore symmetry of worker start/shutdown
chore: fix return type of cal_max_tokens
chore: type some more return values
fix: free resources before re-claiming

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Author: ixlmar <206748156+ixlmar@users.noreply.github.com>, committed 2025-05-16 17:28:10 +02:00 via GitHub
Commit: 13b61405e8 (parent: 7b19acfab1)
2 changed files with 20 additions and 14 deletions

File 1 of 2:

@@ -1,6 +1,7 @@
 import math
 import random
 from collections.abc import Iterable
+from typing import Optional

 import torch
@@ -75,7 +76,8 @@ def get_fraction_from_executor_config(executor_config):
 def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config,
-                   draft_model_config, mapping: Mapping, alloc_kv_tokens: int):
+                   draft_model_config, mapping: Mapping,
+                   alloc_kv_tokens: int) -> int:
     model_kv_size_per_token = get_cache_size_per_token(model_config, mapping)
     draft_kv_size_per_token = get_cache_size_per_token(
         draft_model_config, mapping) if draft_model_config is not None else 0
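The `-> int` annotation pins down that `cal_max_tokens` always yields a number, which is what licenses the None check removed in the `@@ -198,7` hunk below. As a rough illustration of the budget arithmetic such a function performs (a hypothetical sketch; the real body lies outside this diff, and the exact formula is an assumption):

def cal_max_tokens_sketch(peak_memory: int, total_gpu_memory: int,
                          fraction: float, kv_size_per_token: int,
                          alloc_kv_tokens: int) -> int:
    # Hypothetical budget: a fraction of total GPU memory, minus the measured
    # activation peak, plus the memory already pinned by tokens allocated
    # during estimation (reusable once the estimation pool is freed).
    available = (int(total_gpu_memory * fraction) - peak_memory
                 + alloc_kv_tokens * kv_size_per_token)
    return max(0, available // kv_size_per_token)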
@@ -135,12 +137,11 @@ def get_token_num_for_estimation(executor_config, model_config):
     return None


-def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
-                                 model_engine: PyTorchModelEngine,
-                                 executor_config: ExecutorConfig,
-                                 mapping: Mapping, origin_seq_len: int,
-                                 ctx_chunk_config,
-                                 draft_model_engine: PyTorchModelEngine):
+def estimate_max_kv_cache_tokens(
+        py_executor: PyExecutor, model_engine: PyTorchModelEngine,
+        executor_config: ExecutorConfig, mapping: Mapping, origin_seq_len: int,
+        ctx_chunk_config,
+        draft_model_engine: PyTorchModelEngine) -> Optional[int]:
     # TODO: support CP by generating dummy requests for it.
     if 'cp_type' in mapping.cp_config:
         return executor_config.max_num_tokens
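The new `Optional[int]` annotation (hence the `from typing import Optional` added in the first hunk) documents the early exit above: under context parallelism the function passes through `executor_config.max_num_tokens`, which may itself be unset. A self-contained sketch of the contract callers must honor (illustrative names; not the real API):

from typing import Optional

def estimate_sketch(configured_max: Optional[int],
                    cp_enabled: bool) -> Optional[int]:
    # Mirrors the early exit above: under CP the configured value is
    # returned as-is, and it may be None.
    if cp_enabled:
        return configured_max
    return 4096  # stand-in for the real estimation

budget = estimate_sketch(None, cp_enabled=True)
if budget is not None:  # the guard the second file of this commit applies
    print(f"KV cache capped at {budget} tokens")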
@@ -198,7 +199,7 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
         model_engine.model.model_config, draft_model_config, mapping,
         kv_stats.max_num_blocks * kv_stats.tokens_per_block)
-    if kv_cache_max_tokens_in is not None and kv_cache_max_tokens is not None:
+    if kv_cache_max_tokens_in is not None:
         kv_cache_max_tokens = min(kv_cache_max_tokens, kv_cache_max_tokens_in)
     logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}")
@@ -209,8 +210,7 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
         "kv_cache_manager").shutdown()

     py_executor.is_warmup = False
-    if py_executor.dist.mapping.rank == 0:
-        py_executor.shutdown()
+    py_executor.shutdown()

     return kv_cache_max_tokens
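Dropping the `rank == 0` guard is the "restore symmetry of worker start/shutdown" item from the commit message: if the start path ran on every rank, shutdown must too, or the other ranks leak their workers. The pairing in miniature (a sketch with hypothetical names):

class WorkerSketch:
    def __init__(self) -> None:
        self.running = False

    def start(self) -> None:
        self.running = True

    def shutdown(self) -> None:
        self.running = False

workers = [WorkerSketch() for _ in range(4)]  # one per rank, illustratively
for w in workers:
    w.start()
for w in workers:   # symmetric: shutdown everywhere start ran,
    w.shutdown()    # not only on rank 0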
@@ -299,6 +299,7 @@ def create_kv_cache_manager(model_engine: PyTorchModelEngine, mapping: Mapping,
     if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY:
         executor_config.max_seq_len = kv_cache_manager.max_seq_len

+    assert kv_cache_manager is not None
     return kv_cache_manager
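The added assert turns `create_kv_cache_manager`'s implicit non-None guarantee into an explicit one, failing fast at the producer; that is what lets the call sites in the second file drop the `... is not None` halves of their conditions. The general shape of the idiom (an illustrative sketch, not the real code):

from typing import Optional

_RESOURCES: dict[str, object] = {"kv_cache_manager": object()}

def get_manager_sketch(key: str) -> object:
    manager: Optional[object] = _RESOURCES.get(key)
    assert manager is not None  # fail here, not at every call site
    return manager              # callers may rely on a real object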
@ -312,7 +313,7 @@ def create_py_executor_instance(dist,
draft_model_engine,
start_worker,
decoder,
lora_config: LoraConfig = None):
lora_config: LoraConfig = None) -> PyExecutor:
kv_cache_manager = resources.get(KV_CACHE_MANAGER_KEY, None)
spec_config = model_engine.spec_config

File 2 of 2:

@@ -19,12 +19,13 @@ from .config import PyTorchConfig
 from .config_utils import is_mla
 from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
                            PyTorchModelEngine)
+from .py_executor import PyExecutor


 def create_py_executor(executor_config: ExecutorConfig,
                        checkpoint_dir: str = None,
                        engine_dir: str = None,
-                       lora_config: LoraConfig = None):
+                       lora_config: LoraConfig = None) -> PyExecutor:
     if executor_config.pytorch_backend_config is None:
         executor_config.pytorch_backend_config = PyTorchConfig()
@@ -197,23 +198,27 @@ def create_py_executor(executor_config: ExecutorConfig,
             origin_seq_len, ctx_chunk_config, draft_model_engine)

         # This may be None if no max number tokens set and enable cp.
         if kv_cache_max_tokens is not None:
+            del py_executor  # free before constructing new
+            del kv_cache_manager  # free before constructing new
             executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
             kv_cache_manager = create_kv_cache_manager(model_engine, mapping,
                                                        executor_config)
             resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager
-            if model_engine.attn_metadata is not None and kv_cache_manager is not None:
+            if model_engine.attn_metadata is not None:
                 if pytorch_backend_config.use_cuda_graph:
                     model_engine._release_cuda_graphs()
                 del model_engine.attn_metadata
                 model_engine.attn_metadata = None

             if draft_model_engine is not None:
+                del draft_kv_cache_manager  # free before constructing new
                 draft_kv_cache_manager = create_kv_cache_manager(
                     draft_model_engine, mapping, executor_config)
                 resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
-                if draft_model_engine.attn_metadata is not None and draft_kv_cache_manager is not None:
+                if draft_model_engine.attn_metadata is not None:
                     if pytorch_backend_config.use_cuda_graph:
                         draft_model_engine._release_cuda_graphs()
                     del draft_model_engine.attn_metadata
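The three added `del` statements implement the "free resources before re-claiming" fix: dropping the last Python reference to the estimation-time executor and KV cache managers lets their device memory be released before the resized replacements are constructed, so peak usage is one pool rather than two. The idiom in miniature (a sketch; assumes no other references keep the old object alive):

class PoolSketch:
    def __init__(self, n_tokens: int) -> None:
        self.buf = bytearray(n_tokens)  # stand-in for a device allocation

pool = PoolSketch(1 << 20)  # pool sized for estimation
del pool                    # release first: with CPython refcounting the
                            # buffer is freed immediately
pool = PoolSketch(2 << 20)  # then claim the final, resized pool

Without the `del`, the new pool would be built while the old one was still referenced, and both would coexist until the assignment completed.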