mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[refactor] Clean up drafter/resource manager creation logic (#5805)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
This commit is contained in:
parent
e0836f9ca9
commit
fa34cb7234
@ -360,18 +360,19 @@ def create_py_executor(
|
||||
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
|
||||
kv_cache_creator.build_managers(resources)
|
||||
|
||||
# Drafter for speculative decoding
|
||||
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
|
||||
drafter = get_spec_drafter(model_engine)
|
||||
|
||||
# Resource managers for speculative decoding
|
||||
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
|
||||
# to provide a resource manager if required.
|
||||
spec_resource_manager = get_spec_resource_manager(model_engine,
|
||||
draft_model_engine,
|
||||
drafter)
|
||||
draft_model_engine)
|
||||
if spec_resource_manager is not None:
|
||||
resources[
|
||||
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
|
||||
|
||||
# Drafter for speculative decoding
|
||||
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
|
||||
drafter = get_spec_drafter(model_engine, spec_resource_manager)
|
||||
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
|
||||
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
|
||||
|
||||
@ -1,18 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from ..pyexecutor.resource_manager import BaseResourceManager
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
|
||||
|
||||
class Drafter(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_resource_manager: Optional[BaseResourceManager] = None,
|
||||
):
|
||||
self.spec_resource_manager = spec_resource_manager
|
||||
|
||||
@abstractmethod
|
||||
def prepare_draft_tokens(
|
||||
self,
|
||||
|
||||
@ -167,8 +167,8 @@ class NGramDrafter(Drafter):
|
||||
ngram_pool_manager: NGramPoolManager = None,
|
||||
):
|
||||
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
|
||||
super().__init__(spec_resource_manager=ngram_pool_manager)
|
||||
self.max_draft_len = spec_config.max_draft_len
|
||||
self.spec_resource_manager = ngram_pool_manager
|
||||
|
||||
def prepare_draft_tokens(
|
||||
self,
|
||||
|
||||
@ -55,9 +55,7 @@ def get_spec_metadata(spec_config,
|
||||
return None
|
||||
|
||||
|
||||
def get_spec_resource_manager(model_engine,
|
||||
draft_model_engine=None,
|
||||
drafter=None):
|
||||
def get_spec_resource_manager(model_engine, draft_model_engine=None):
|
||||
spec_config = model_engine.spec_config
|
||||
if spec_config is None:
|
||||
return None
|
||||
@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine,
|
||||
max_seq_len,
|
||||
max_num_tokens,
|
||||
)
|
||||
if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
|
||||
assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
|
||||
return drafter.spec_resource_manager
|
||||
if spec_dec_mode.is_ngram():
|
||||
return NGramPoolManager(spec_config, max_num_requests)
|
||||
if spec_dec_mode.is_user_provided():
|
||||
return spec_config.resource_manager
|
||||
return None
|
||||
|
||||
|
||||
@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
|
||||
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
|
||||
|
||||
|
||||
def get_spec_drafter(model_engine):
|
||||
def get_spec_drafter(model_engine, spec_resource_manager):
|
||||
spec_config = model_engine.spec_config
|
||||
max_num_requests = model_engine.batch_size
|
||||
if spec_config is None:
|
||||
return None
|
||||
if spec_config.spec_dec_mode.is_ngram():
|
||||
return NGramDrafter(spec_config,
|
||||
NGramPoolManager(spec_config, max_num_requests))
|
||||
return NGramDrafter(spec_config, spec_resource_manager)
|
||||
if spec_config.spec_dec_mode.is_user_provided():
|
||||
return spec_config.drafter
|
||||
return None
|
||||
|
||||
@ -354,8 +354,9 @@ class EagleDecodingConfig(DecodingBaseConfig):
|
||||
|
||||
|
||||
class UserProvidedDecodingConfig(DecodingBaseConfig):
|
||||
# Type should be Drafter, but it leads to circular import
|
||||
drafter: object
|
||||
# Cannot use real type annotations due to circular imports
|
||||
drafter: object # Type is Drafter
|
||||
resource_manager: object = None # Type is Optional[ResourceManager]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user