mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Sync block_size from EngineCore to frontend for hybrid Mamba… (#42967)
Signed-off-by: Amit Gruner <agruner@crusoe.ai> Co-authored-by: Amit Gruner <agruner@crusoe.ai> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -27,12 +27,13 @@ from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import LateInteractionParams, PoolingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine import EngineCoreReadyResponse, EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (
|
||||
AsyncMPClient,
|
||||
DPLBAsyncMPClient,
|
||||
EngineCoreClient,
|
||||
MPClient,
|
||||
SyncMPClient,
|
||||
)
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
@@ -236,6 +237,30 @@ def test_dplb_non_late_interaction_still_uses_lb():
|
||||
assert client.lb_engines[1][0] == 1
|
||||
|
||||
|
||||
def test_apply_ready_response_syncs_block_size():
|
||||
import msgspec
|
||||
|
||||
client = object.__new__(MPClient)
|
||||
client.vllm_config = SimpleNamespace(
|
||||
cache_config=SimpleNamespace(block_size=16, num_gpu_blocks=0),
|
||||
model_config=SimpleNamespace(max_model_len=8192),
|
||||
)
|
||||
client.stats_update_address = None
|
||||
|
||||
payload = msgspec.msgpack.encode(
|
||||
EngineCoreReadyResponse(
|
||||
max_model_len=8192,
|
||||
num_gpu_blocks=100,
|
||||
block_size=1056,
|
||||
dp_stats_address=None,
|
||||
dtype="bfloat16",
|
||||
vllm_version="test",
|
||||
)
|
||||
)
|
||||
client._apply_ready_response(payload)
|
||||
assert client.vllm_config.cache_config.block_size == 1056
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
while True:
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
@@ -74,6 +74,7 @@ class EngineCoreReadyResponse:
|
||||
|
||||
max_model_len: int
|
||||
num_gpu_blocks: int
|
||||
block_size: int
|
||||
dp_stats_address: str | None
|
||||
dtype: str
|
||||
vllm_version: str
|
||||
|
||||
@@ -1462,6 +1462,7 @@ class EngineCoreProc(EngineCore):
|
||||
ready_response = EngineCoreReadyResponse(
|
||||
max_model_len=self.vllm_config.model_config.max_model_len,
|
||||
num_gpu_blocks=self.vllm_config.cache_config.num_gpu_blocks or 0,
|
||||
block_size=self.vllm_config.cache_config.block_size,
|
||||
dp_stats_address=self.frontend_stats_publish_address,
|
||||
dtype=str(self.vllm_config.model_config.dtype).removeprefix("torch."),
|
||||
vllm_version=VLLM_VERSION,
|
||||
|
||||
@@ -713,6 +713,10 @@ class MPClient(EngineCoreClient):
|
||||
num_gpu_blocks += response.num_gpu_blocks
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# Sync block_size: may be enlarged by _align_hybrid_block_size in the
|
||||
# worker for hybrid Mamba models.
|
||||
vllm_config.cache_config.block_size = response.block_size
|
||||
|
||||
# In external DP LB mode, the coordinator address that the
|
||||
# front-end procs connect to is obtained by each engine via it's
|
||||
# initial handshake with the rank 0 front-end.
|
||||
|
||||
Reference in New Issue
Block a user