[refactor] Clean up drafter/resource manager creation logic (#5805)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
This commit is contained in:
Mike Iovine 2025-07-16 15:45:46 -04:00 committed by GitHub
parent e0836f9ca9
commit fa34cb7234
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 18 additions and 27 deletions

View File

@ -360,18 +360,19 @@ def create_py_executor(
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
kv_cache_creator.build_managers(resources)
# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
drafter = get_spec_drafter(model_engine)
# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
# to provide a resource manager if required.
spec_resource_manager = get_spec_resource_manager(model_engine,
draft_model_engine,
drafter)
draft_model_engine)
if spec_resource_manager is not None:
resources[
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
drafter = get_spec_drafter(model_engine, spec_resource_manager)
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):

View File

@ -1,18 +1,10 @@
from abc import ABC, abstractmethod
from typing import Optional
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
class Drafter(ABC):
def __init__(
self,
spec_resource_manager: Optional[BaseResourceManager] = None,
):
self.spec_resource_manager = spec_resource_manager
@abstractmethod
def prepare_draft_tokens(
self,

View File

@ -167,8 +167,8 @@ class NGramDrafter(Drafter):
ngram_pool_manager: NGramPoolManager = None,
):
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
super().__init__(spec_resource_manager=ngram_pool_manager)
self.max_draft_len = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager
def prepare_draft_tokens(
self,

View File

@ -55,9 +55,7 @@ def get_spec_metadata(spec_config,
return None
def get_spec_resource_manager(model_engine,
draft_model_engine=None,
drafter=None):
def get_spec_resource_manager(model_engine, draft_model_engine=None):
spec_config = model_engine.spec_config
if spec_config is None:
return None
@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine,
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
return drafter.spec_resource_manager
if spec_dec_mode.is_ngram():
return NGramPoolManager(spec_config, max_num_requests)
if spec_dec_mode.is_user_provided():
return spec_config.resource_manager
return None
@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
def get_spec_drafter(model_engine):
def get_spec_drafter(model_engine, spec_resource_manager):
spec_config = model_engine.spec_config
max_num_requests = model_engine.batch_size
if spec_config is None:
return None
if spec_config.spec_dec_mode.is_ngram():
return NGramDrafter(spec_config,
NGramPoolManager(spec_config, max_num_requests))
return NGramDrafter(spec_config, spec_resource_manager)
if spec_config.spec_dec_mode.is_user_provided():
return spec_config.drafter
return None

View File

@ -354,8 +354,9 @@ class EagleDecodingConfig(DecodingBaseConfig):
class UserProvidedDecodingConfig(DecodingBaseConfig):
# Type should be Drafter, but it leads to circular import
drafter: object
# Cannot use real type annotations due to circular imports
drafter: object # Type is Drafter
resource_manager: object = None # Type is Optional[ResourceManager]
@classmethod
def from_dict(cls, data: dict):