From fa34cb723457c77292109a9e788681263dc69adf Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Wed, 16 Jul 2025 15:45:46 -0400 Subject: [PATCH] [refactor] Clean up drafter/resource manager creation logic (#5805) Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com> --- .../_torch/pyexecutor/py_executor_creator.py | 13 +++++++------ tensorrt_llm/_torch/speculative/drafter.py | 8 -------- tensorrt_llm/_torch/speculative/ngram.py | 2 +- tensorrt_llm/_torch/speculative/utils.py | 17 +++++++---------- tensorrt_llm/llmapi/llm_args.py | 5 +++-- 5 files changed, 18 insertions(+), 27 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index b99037d8a0..09976cb512 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -360,18 +360,19 @@ def create_py_executor( if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE): kv_cache_creator.build_managers(resources) - # Drafter for speculative decoding - with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): - drafter = get_spec_drafter(model_engine) - # Resource managers for speculative decoding + # For user-specified drafters, use extra_resource_managers in PyTorchBackend config + # to provide a resource manager if required. spec_resource_manager = get_spec_resource_manager(model_engine, - draft_model_engine, - drafter) + draft_model_engine) if spec_resource_manager is not None: resources[ ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager + # Drafter for speculative decoding + with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): + drafter = get_spec_drafter(model_engine, spec_resource_manager) + with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_EXTRA_RESOURCES if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES): diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index d0f5a44d77..d99c5dd92d 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,18 +1,10 @@ from abc import ABC, abstractmethod -from typing import Optional -from ..pyexecutor.resource_manager import BaseResourceManager from ..pyexecutor.scheduler import ScheduledRequests class Drafter(ABC): - def __init__( - self, - spec_resource_manager: Optional[BaseResourceManager] = None, - ): - self.spec_resource_manager = spec_resource_manager - @abstractmethod def prepare_draft_tokens( self, diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 1d015a58b9..57f3045e66 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -167,8 +167,8 @@ class NGramDrafter(Drafter): ngram_pool_manager: NGramPoolManager = None, ): assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." - super().__init__(spec_resource_manager=ngram_pool_manager) self.max_draft_len = spec_config.max_draft_len + self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( self, diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 882dfdf924..667d1a14b0 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -55,9 +55,7 @@ def get_spec_metadata(spec_config, return None -def get_spec_resource_manager(model_engine, - draft_model_engine=None, - drafter=None): +def get_spec_resource_manager(model_engine, draft_model_engine=None): spec_config = model_engine.spec_config if spec_config is None: return None @@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine, max_seq_len, max_num_tokens, ) - if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided(): - assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding." - return drafter.spec_resource_manager + if spec_dec_mode.is_ngram(): + return NGramPoolManager(spec_config, max_num_requests) + if spec_dec_mode.is_user_provided(): + return spec_config.resource_manager return None @@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") -def get_spec_drafter(model_engine): +def get_spec_drafter(model_engine, spec_resource_manager): spec_config = model_engine.spec_config - max_num_requests = model_engine.batch_size if spec_config is None: return None if spec_config.spec_dec_mode.is_ngram(): - return NGramDrafter(spec_config, - NGramPoolManager(spec_config, max_num_requests)) + return NGramDrafter(spec_config, spec_resource_manager) if spec_config.spec_dec_mode.is_user_provided(): return spec_config.drafter return None diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 76fbaf473b..111d779ef3 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -354,8 +354,9 @@ class EagleDecodingConfig(DecodingBaseConfig): class UserProvidedDecodingConfig(DecodingBaseConfig): - # Type should be Drafter, but it leads to circular import - drafter: object + # Cannot use real type annotations due to circular imports + drafter: object # Type is Drafter + resource_manager: object = None # Type is Optional[ResourceManager] @classmethod def from_dict(cls, data: dict):