Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Support the two-model flow with no overlap scheduler or chain drafter. The drafting model runs in the PyTorch backend.
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
This commit is contained in: parent 1f8ed71d5f, commit a9eb5afc9f
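For orientation, a minimal usage sketch of the two-model Eagle3 flow this commit enables, adapted from the new test added at the end of this diff (model paths and the final generate call are placeholders, not part of the commit):

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

llm = LLM(
    model="<path/to/target-model>",  # placeholder path
    speculative_config=EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="<path/to/eagle3-draft-model>",  # placeholder path
        eagle3_one_model=False,            # two-model flow
        eagle3_layers_to_capture=None,     # fall back to the default capture layers
    ),
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
    disable_overlap_scheduler=True,        # the overlap scheduler is not supported in this flow
    world_size=1,
)
out = llm.generate("What is the capital of France?", SamplingParams(max_tokens=32))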
@@ -75,6 +75,8 @@ transforms:
     stage: pattern_matcher
   quantize_mxfp4_moe:
     stage: pattern_matcher
+  detect_hidden_states_for_capture:
+    stage: pattern_matcher
   detect_sharding:
     stage: sharding
     simple_shard_only: false
@@ -163,6 +165,9 @@ transforms:
   insert_cached_delta_rule:
     stage: cache_init
     backend: fla_delta
+  insert_cached_residual_add:
+    stage: cache_init
+    backend: cached_residual_add
   initialize_cache:
     stage: cache_init
     run_per_gm: false
@@ -8,7 +8,14 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
+from ...llmapi.llm_args import (
+    BaseLlmArgs,
+    BuildConfig,
+    EagleDecodingConfig,
+    KvCacheConfig,
+    SamplerType,
+    _ParallelConfig,
+)
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger
@@ -150,6 +157,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
 
     enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")
 
+    draft_checkpoint_loader: Optional[object] = Field(
+        default=None,
+        description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
+    )
+
     ### INFERENCE OPTIMIZER CONFIG #################################################################
     mode: Literal["graph", "transformers"] = Field(
         default="graph",
@@ -190,11 +202,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         ),
     )
 
-    draft_checkpoint_loader: Optional[object] = Field(
-        default=None,
-        description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
-    )
-
     ### SEQUENCE INTERFACE CONFIG ##################################################################
     max_input_len: int = Field(default=1024, description="The maximum input length.")
     max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@@ -420,6 +427,19 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
         msg = "AutoDeploy only supports parallelization via the `world_size` argument."
         return _check_for_default_value_only(cls, value, info, msg)
 
+    @model_validator(mode="after")
+    def setup_hidden_state_capture(self):
+        if self.speculative_config is None or not isinstance(
+            self.speculative_config, EagleDecodingConfig
+        ):
+            return self
+
+        self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
+        self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
+            self.speculative_config.eagle3_layers_to_capture
+        )
+        return self
+
     @model_validator(mode="after")
     def validate_parallel_config(self):
         """Setup parallel config according to world_size.
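To illustrate what the setup_hidden_state_capture validator above does (values invented for illustration): given an EagleDecodingConfig whose eagle3_layers_to_capture is {1, 15, 28}, the resolved transforms entry ends up equivalent to:

transforms = {
    "detect_hidden_states_for_capture": {
        "stage": "pattern_matcher",
        "capture_hidden_states": True,            # switched on by the validator
        "eagle3_layers_to_capture": {1, 15, 28},  # propagated from the speculative config
    },
    # ...remaining transforms unchanged...
}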
@@ -13,14 +13,17 @@ import copy
 import types
 from collections import defaultdict
 from dataclasses import dataclass
-from types import SimpleNamespace
+from types import MethodType, SimpleNamespace
 from typing import Dict, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from strenum import StrEnum
 from torch._prims_common import DeviceLikeType
 
 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
+from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
+from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
 from tensorrt_llm._torch.pyexecutor._util import (
     _create_kv_cache_manager,
     get_decoding_mode,
@@ -32,9 +35,11 @@ from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, get_draft_tok
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._torch.speculative import get_spec_drafter
+from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
+    EagleDecodingConfig,
     LoadFormat,
     SamplerType,
     TorchLlmArgs,
@@ -57,6 +62,7 @@ from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
+    RequestList,
     ScheduledRequests,
     SimpleScheduler,
 )
@@ -113,6 +119,90 @@ class _CacheManagerWithFakePool(KVCacheManager):
         return self.num_blocks, 0
 
 
+class ADHiddenStateManager(Eagle3ResourceManager):
+    def __init__(
+        self,
+        cache_seq_interface: CachedSequenceInterface,
+        config: EagleDecodingConfig,
+        max_num_requests: int,
+        max_seq_len: int,
+        max_num_tokens: int,
+    ):
+        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
+        dtype = hidden_state_buffer.dtype
+        hidden_size = hidden_state_buffer.shape[1]
+
+        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)
+
+        self.hidden_state_write_indices: torch.Tensor = torch.empty(
+            max_num_tokens, dtype=torch.long, device="cuda"
+        )
+
+    def _get_hidden_state_buffers(
+        self, cache_seq_interface: CachedSequenceInterface
+    ) -> List[torch.Tensor]:
+        hidden_state_buffers = []
+        for name, tensor in cache_seq_interface.named_args.items():
+            if "hidden_states_cache" in name:
+                hidden_state_buffers.append(tensor)
+
+        if not hidden_state_buffers:
+            raise ValueError(
+                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
+            )
+        return hidden_state_buffers
+
+    def prepare_hidden_states_capture(
+        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
+    ) -> None:
+        """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
+        seq_lens = cache_seq_interface.info.seq_len
+        num_tokens = sum(seq_lens)
+
+        start_idx = 0
+        hidden_states_write_indices = []
+        for request, seq_len in zip(ordered_requests, seq_lens):
+            request_id = request.request_id
+            slot_id = self.slot_manager.get_slot(request_id)
+            self.start_indices[slot_id] = start_idx
+            hidden_states_write_indices.extend(range(start_idx, start_idx + seq_len))
+            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
+            assert start_idx < self.hidden_states.shape[0], (
+                f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
+            )
+
+        if len(hidden_states_write_indices) != num_tokens:
+            raise ValueError(
+                f"len(hidden_state_write_indices) ({len(hidden_states_write_indices)}) != num_tokens \
+                ({num_tokens}). Check whether ordered_requests matches up with seq_lens."
+            )
+
+        hidden_state_write_indices_host = torch.tensor(
+            hidden_states_write_indices, dtype=torch.long
+        )
+
+        self.hidden_state_write_indices[:num_tokens].copy_(
+            hidden_state_write_indices_host, non_blocking=True
+        )
+
+    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
+        """Capture configured hidden states that have been written by the model,
+        in a format that can be used by the draft model.
+        """
+        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)
+        if not full_hidden_states:
+            return
+
+        num_tokens = sum(cache_seq_interface.info.seq_len)
+
+        hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
+        hidden_states = torch.cat(hidden_states, dim=1)
+        hidden_states = hidden_states.to(dtype=self.dtype)
+
+        token_idx = self.hidden_state_write_indices[:num_tokens]
+        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
+
+
 def construct_draft_llm_args(
     ad_config: LlmArgs,
 ) -> TorchLlmArgs:
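To make the row bookkeeping in prepare_hidden_states_capture above concrete, here is a small standalone sketch (seq_lens and max_total_draft_tokens are made-up values, not from the commit): each request reserves at least max_total_draft_tokens + 1 rows in the flat hidden-state buffer, and only the first seq_len rows of each block are written this iteration.

# Illustrative only: how the write indices are laid out per request.
max_total_draft_tokens = 3
seq_lens = [5, 1, 1]  # one context request followed by two generation requests

start_idx, start_indices, write_indices = 0, [], []
for seq_len in seq_lens:
    start_indices.append(start_idx)                               # row where this request's block begins
    write_indices.extend(range(start_idx, start_idx + seq_len))   # rows written this iteration
    start_idx += max(seq_len, max_total_draft_tokens + 1)         # reserve room for later drafted tokens

print(start_indices)  # [0, 5, 9]
print(write_indices)  # [0, 1, 2, 3, 4, 5, 9]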
@@ -461,6 +551,10 @@ class ADEngine(ModelEngine):
         kv_cache_manager = resource_manager.get_resource_manager(
             ResourceManagerType.KV_CACHE_MANAGER
         )
+        # resource manager for hidden state capture
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )
 
         # requests in order of context, generate
         context_requests = scheduled_requests.context_requests
@@ -471,6 +565,7 @@
             r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
         ]
         gen_requests = extend_requests + generation_requests
+        ordered_requests = context_requests + gen_requests
         # info to be extracted
         input_ids: List[List[int]] = []
         position_ids: List[List[int]] = []
@@ -670,6 +765,13 @@
 
         self.cache_seq_interface.info.run_host_prepare_for_attention_forward()
 
+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.prepare_hidden_states_capture(
+                ordered_requests, self.cache_seq_interface
+            )
+
         self.iter_states["num_ctx_requests"] = num_ctx_requests
         self.iter_states["num_ctx_tokens"] = num_ctx_tokens
         # TODO: handle extend requests and draft requests for specdec
@@ -710,14 +812,74 @@
         outputs = {
             "logits": self._compute_logits(),
         }
 
+        # save hidden states after running model.forward() in _compute_logits()
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )
+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
+
         if self.mapping is not None:
             self._execute_logit_post_processors(scheduled_requests, outputs)
 
         return outputs
 
 
+def share_target_weights_with_draft(
+    target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+):
+    """
+    Certain speculative decoding methods (e.g. Eagle3) require sharing the target model's embedding and lm_head weights
+    with the draft model. This function does this sharing if necessary.
+    """
+
+    assert isinstance(draft_model_engine.model, Eagle3ForCausalLM), (
+        f"Expected draft_model_engine.model to be Eagle3ForCausalLM, got {type(draft_model_engine.model)}"
+    )
+
+    def share_embedding_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        embedding_weight = get_input_embeddings(target_model_engine.model)
+
+        world_size = mpi_world_size()
+        assert world_size <= 1, f"This code assumes tp<=1. World size: {world_size}"
+
+        # Note: This simple forward function implementation assumes tp=1.
+        # TODO(govind): Handle the tp>1 case.
+        def new_embedding_forward(self, input_ids):
+            return F.embedding(input_ids, self.weight)
+
+        if draft_model_engine.model.model.embed_tokens is None:
+            submodule = torch.nn.Module()
+            submodule.forward = MethodType(new_embedding_forward, submodule)
+            submodule.weight = embedding_weight
+            draft_model_engine.model.model.embed_tokens = submodule
+
+    def share_lm_head_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        vocab_size = target_model_engine.cache_seq_interface.info.vocab_size_padded
+
+        lm_head_weight = get_lm_head_weights(target_model_engine.model)
+
+        assert lm_head_weight.shape[0] == vocab_size, (
+            f"Expected lm_head weight first dimension to be vocab_size={vocab_size}, "
+            f"but got shape {lm_head_weight.shape}"
+        )
+
+        if draft_model_engine.model.load_lm_head_from_target:
+            draft_model_engine.model.lm_head.weight = lm_head_weight
+
+    share_embedding_weights_with_draft(target_model_engine, draft_model_engine)
+    share_lm_head_weights_with_draft(target_model_engine, draft_model_engine)
+
+
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.
 
@@ -745,7 +907,7 @@ def create_draft_model_engine_maybe(
         chunked_prefill=ad_config.enable_chunked_prefill,
         cache_reuse=kv_cache_config.enable_block_reuse,
         has_speculative_draft_tokens=has_spec_drafter,
-        chunk_size=engine.llm_args.max_num_tokens,
+        chunk_size=target_engine.llm_args.max_num_tokens,
     )
 
     # Construct TorchLlmArgs for the draft model
@@ -753,6 +915,10 @@
         ad_config=ad_config,
     )
 
+    # chain drafter is not supported currently for AutoDeploy.
+    # TODO(govind): Do this when we want to optimize 2-model spec dec performance.
+    drafting_loop_wrapper = None
+
     draft_model_engine = PyTorchModelEngine(
         model_path=draft_spec_config.speculative_model_dir,
         llm_args=draft_llm_args,
@@ -761,9 +927,14 @@
         dist=mpi_dist,
        spec_config=draft_spec_config,
         is_draft_model=True,
-        drafting_loop_wrapper=None,
+        drafting_loop_wrapper=drafting_loop_wrapper,
     )
 
+    if draft_spec_config.spec_dec_mode.is_eagle3():
+        share_target_weights_with_draft(
+            target_model_engine=target_engine, draft_model_engine=draft_model_engine
+        )
+
     draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
 
     return draft_model_engine
@@ -855,9 +1026,11 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
 
     spec_config = ad_config.speculative_config
-    if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
+    if spec_config is not None and not (
+        spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
+    ):
         raise ValueError(
-            "Currently, AutoDeploy only supports speculative decoding in draft target mode."
+            "Currently, AutoDeploy only supports speculative decoding in draft target or eagle3 mode."
         )
 
     if spec_config is not None and ad_config.guided_decoding_backend is not None:
@@ -865,11 +1038,20 @@
             "Guided decoding is not currently supported for speculative decoding in AutoDeploy."
         )
 
-    # Speculative resource manager not needed for DraftTargetDecoding.
-    spec_resource_manager = None
-
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
     )
 
+    spec_resource_manager = (
+        ADHiddenStateManager(
+            cache_seq_interface=engine.cache_seq_interface,
+            config=spec_config,
+            max_num_requests=ad_config.max_batch_size,
+            max_seq_len=engine.llm_args.max_seq_len,
+            max_num_tokens=engine.llm_args.max_num_tokens,
+        )
+        if isinstance(spec_config, EagleDecodingConfig)
+        else None
+    )
+
     # check kvcache config for partial block reuse
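A small standalone sketch of the forward-binding idiom used in share_target_weights_with_draft above: a bare nn.Module is given a forward bound via MethodType so the draft model can reuse the target's embedding weight without copying it. Sizes and names are made up for illustration; this is not the commit's code.

import types

import torch
import torch.nn.functional as F

target_embedding_weight = torch.randn(100, 16)  # stand-in for the target model's embedding table

def new_embedding_forward(self, input_ids):
    # Simple tp=1 embedding lookup against the shared weight tensor.
    return F.embedding(input_ids, self.weight)

stub = torch.nn.Module()
stub.weight = target_embedding_weight                         # shared, not copied
stub.forward = types.MethodType(new_embedding_forward, stub)  # bind a forward to the bare module

tokens = torch.tensor([[1, 2, 3]])
print(stub(tokens).shape)  # torch.Size([1, 3, 16])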
@@ -25,7 +25,9 @@ from typing import Tuple
 import torch
 from torch.fx import GraphModule
 
-from ...utils.node_utils import is_linear_op, is_op
+from tensorrt_llm._torch.auto_deploy.utils._graph import get_lm_head_node
+
+from ...utils.node_utils import is_linear_op
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
 
 
@@ -54,9 +56,7 @@ class GatherLogitsBeforeLmHeadTransform(BaseTransform):
         self._log_info("Applying GatherLogitsBeforeLmHead transform...")
 
         # assume lm head node is the input to the output node
-        lm_head_node = gm.graph.find_nodes(op="output")[0].all_input_nodes[0]
-        if is_op(lm_head_node, torch.ops.aten.to):
-            lm_head_node = lm_head_node.all_input_nodes[0]
+        lm_head_node = get_lm_head_node(gm)
 
         if is_linear_op(lm_head_node):
             node_to_gather = lm_head_node.all_input_nodes[0]
@@ -0,0 +1,239 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The transform passes to capture the hidden states of the target model."""
+
+from typing import Dict, List, Optional, Set, Tuple, Type
+
+import torch
+from torch._ops import OpOverloadPacket
+from torch.fx import GraphModule, Node
+
+from ...custom_ops.attention_interface import (
+    AttentionDescriptor,
+    AttentionLayout,
+    AttentionRegistry,
+    CacheConfig,
+    CacheInitializerDict,
+    MHACallable,
+    SequenceInfo,
+)
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ...utils.node_utils import get_all_layer_subgraphs, is_op
+from ..interface import (
+    BaseTransform,
+    SharedConfig,
+    TransformConfig,
+    TransformInfo,
+    TransformRegistry,
+)
+from .kvcache import InsertCachedAttention
+
+
+@torch.library.custom_op("auto_deploy::residual_add_for_capture", mutates_args=())
+def residual_add_for_capture(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
+    return torch.ops.aten.add(t1, t2)
+
+
+@residual_add_for_capture.register_fake
+def residual_add_for_capture_fake(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
+    return torch.ops.aten.add(t1, t2)
+
+
+@torch.library.custom_op("auto_deploy::cached_residual_add", mutates_args=())
+def cached_residual_add(
+    t1: torch.Tensor, t2: torch.Tensor, hidden_states_cache: torch.Tensor
+) -> torch.Tensor:
+    ret = torch.ops.aten.add(t1, t2)
+    b, s, _ = ret.shape
+    num_tokens = b * s
+
+    hidden_states_cache[:num_tokens].copy_(ret.view(num_tokens, -1), non_blocking=True)
+    return ret
+
+
+@cached_residual_add.register_fake
+def cached_residual_add_fake(
+    t1: torch.Tensor, t2: torch.Tensor, hidden_states_cache: torch.Tensor
+) -> torch.Tensor:
+    return torch.ops.aten.add(t1, t2)
+
+
+class DetectHiddenStatesForCaptureConfig(TransformConfig):
+    """Configuration for the hidden states detection transform."""
+
+    # Whether to capture hidden states at all. If False we will not capture any layers.
+    capture_hidden_states: bool = False
+
+    # TODO: figure out how to get layers to capture.
+    # We should consider if we can use the layer indices stored in eagle checkpoints, e.g.
+    # https://huggingface.co/nvidia/gpt-oss-120b-Eagle3/blob/main/config.json#L9-L14
+    eagle3_layers_to_capture: Optional[Set[int]] = None
+
+    def set_default_eagle3_layers_to_capture(self, num_hidden_layers: int):
+        """
+        Used to set default layers to capture when we want to capture hidden states, but
+        no layers to capture are provided.
+        """
+        if num_hidden_layers <= 6:
+            raise ValueError("Not enough hidden layers for default EAGLE3 capture")
+        self.eagle3_layers_to_capture = {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}
+
+
+@TransformRegistry.register("detect_hidden_states_for_capture")
+class DetectHiddenStatesForCapture(BaseTransform):
+    """Detect the hidden states we should capture in the graph."""
+
+    config: DetectHiddenStatesForCaptureConfig
+
+    @classmethod
+    def get_config_class(cls) -> Type[TransformConfig]:
+        return DetectHiddenStatesForCaptureConfig
+
+    def collect_residual_add_nodes(self, gm: GraphModule) -> Dict[int, Node]:
+        def _get_layer_number(lin_node: Node) -> Optional[int]:
+            weight = lin_node.args[1]
+            if weight.op == "get_attr":
+                subnames = weight.target.split(".")
+                for subname in subnames:
+                    if subname.isdigit():
+                        return int(subname)
+
+            return None
+
+        # find last closing linear node of each layer
+        # from there we will find the residual add node for that layer
+        layer_subgraphs, unprocessed_linear_nodes = get_all_layer_subgraphs(gm)
+        residual_add_nodes: Dict[int, Node] = {}
+        for layer_subgraph in layer_subgraphs:
+            lin_node_closing = layer_subgraph.terminating_node
+            # need layer number to correctly identify the residual add node
+            layer_number = _get_layer_number(lin_node_closing)
+            if layer_number is None:
+                continue
+
+            # Conditions to identify as the hidden states after the residual
+            # The first node after the linear closing node that satisfies:
+            # 1. is an add node with > 1 users (hidden states before the last are used directly by next layer
+            #    as well as having a residual add to the next hidden state). Stopping here prevents us from
+            #    using a future residual add node for the next layer.
+            # 2. is the last add node in a 1 user chain (for last layer or layers with no following residual add)
+            #    This stops us before we go to the next layer.
+            res_node = lin_node_closing
+            while len(res_node.users) == 1:
+                user_node = list(res_node.users)[0]
+                if not is_op(user_node, torch.ops.aten.add):
+                    break
+                res_node = user_node
+
+            if is_op(res_node, torch.ops.aten.add):
+                # this stores the last residual add node encountered for each layer
+                residual_add_nodes[layer_number] = res_node
+
+        return residual_add_nodes
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        if not self.config.capture_hidden_states:
+            info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
+            return gm, info
+
+        if gm.graph.find_nodes(
+            op="call_function", target=torch.ops.auto_deploy.residual_add_for_capture.default
+        ):
+            info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
+            return gm, info
+
+        residual_add_nodes = self.collect_residual_add_nodes(gm)
+
+        if self.config.eagle3_layers_to_capture is None:
+            num_hidden_layers = len(residual_add_nodes)
+            self.config.set_default_eagle3_layers_to_capture(num_hidden_layers)
+
+        residual_add_nodes = {
+            k: v for k, v in residual_add_nodes.items() if k in self.config.eagle3_layers_to_capture
+        }
+
+        assert residual_add_nodes.keys() == self.config.eagle3_layers_to_capture, (
+            f"Unable to find residual add nodes for layers. Expected: {self.config.eagle3_layers_to_capture}, \
+            Found: {residual_add_nodes.keys()}"
+        )
+
+        # replace residual add nodes with special placeholder nodes
+        for layer_number, res_node in residual_add_nodes.items():
+            with gm.graph.inserting_before(res_node):
+                new_node = gm.graph.call_function(
+                    torch.ops.auto_deploy.residual_add_for_capture.default,
+                    args=res_node.args,
+                    kwargs=res_node.kwargs,
+                )
+            res_node.replace_all_uses_with(new_node)
+            gm.graph.erase_node(res_node)
+
+        cnt = len(residual_add_nodes)
+        info = TransformInfo(
+            skipped=False, num_matches=cnt, is_clean=(cnt == 0), has_valid_shapes=(cnt == 0)
+        )
+        return gm, info
+
+
+@AttentionRegistry.register("cached_residual_add")
+class CachedResidualAdd(AttentionDescriptor):
+    @classmethod
+    def is_paged(cls) -> bool:
+        return True
+
+    @classmethod
+    def get_attention_layout(cls) -> AttentionLayout:
+        return "bsnd"
+
+    @classmethod
+    def get_num_qkv_args(cls) -> int:
+        return 2
+
+    @classmethod
+    def get_source_attention_op(cls) -> OpOverloadPacket:
+        return torch.ops.auto_deploy.residual_add_for_capture
+
+    @classmethod
+    def get_cached_attention_op(cls) -> MHACallable:
+        return torch.ops.auto_deploy.cached_residual_add
+
+    @classmethod
+    def get_cache_initializers(
+        cls, source_attn_node: Node, cache_config: CacheConfig
+    ) -> CacheInitializerDict:
+        hidden_size = source_attn_node.meta["val"].shape[-1]
+        hidden_type = source_attn_node.meta["val"].dtype
+
+        def _get_hidden_states_cache(si: SequenceInfo):
+            return torch.empty(si.max_num_tokens, hidden_size, device=si.device, dtype=hidden_type)
+
+        return {"hidden_states_cache": _get_hidden_states_cache}
+
+    @classmethod
+    def get_standard_metadata_args(cls) -> List[str]:
+        return []
+
+
+@TransformRegistry.register("insert_cached_residual_add")
+class InsertCachedResidualAdd(InsertCachedAttention):
+    """A transform to handle residual add cache operations."""
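A quick worked example for set_default_eagle3_layers_to_capture in the new transform file above: a target with num_hidden_layers = 32 defaults to capturing layers {1, 32 // 2 - 1, 32 - 4} = {1, 15, 28}, while a model with 6 or fewer hidden layers is rejected with a ValueError. The cached_residual_add op then flattens each captured (batch, seq, hidden) residual output into the first batch * seq rows of the per-layer hidden_states_cache buffer.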
@@ -17,7 +17,7 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils._pytree import _LEAF_SPEC
 
 from .logger import ad_logger
-from .node_utils import is_op
+from .node_utils import get_weight_tensor, is_op
 
 _NoValType = type("_NoValType", (), {})
 _NO_VAL = _NoValType()
@@ -344,3 +344,60 @@ def placeholders_on_meta(mod: nn.Module) -> bool:
             return True
 
     return False
+
+
+def get_input_embeddings(model: nn.Module) -> torch.Tensor:
+    """Find the unique embedding node across all graph modules."""
+    embedding_weights = []
+    for _, gm in named_graphmodules(model):
+        found_nodes = gm.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.embedding.default
+        )
+        for node in found_nodes:
+            embedding_weights.append(get_weight_tensor(gm, node))
+
+    if hasattr(model, "get_input_embeddings"):
+        embedding_weights.append(model.get_input_embeddings())
+
+    for _, gm in named_graphmodules(model):
+        if hasattr(gm, "get_input_embeddings"):
+            embedding_weights.append(gm.get_input_embeddings())
+
+    assert len(embedding_weights) > 0, "No embedding weights found"
+    unique_embedding_weights = [embedding_weights[0]]
+    for weight in embedding_weights:
+        if weight is not unique_embedding_weights[0]:
+            unique_embedding_weights.append(weight)
+
+    assert len(unique_embedding_weights) == 1, (
+        f"Expected exactly 1 unique embedding weight, but found {len(unique_embedding_weights)}."
+    )
+
+    return unique_embedding_weights[0]
+
+
+def get_output_node(model: nn.Module) -> tuple[GraphModule, Node]:
+    """Find the unique output node across all graph modules."""
+    output_nodes = []
+    for _, gm in named_graphmodules(model):
+        output_nodes.extend([(gm, node) for node in gm.graph.find_nodes(op="output")])
+
+    assert len(output_nodes) == 1, f"Expected exactly 1 output node, but found {len(output_nodes)}."
+    return output_nodes[0]
+
+
+def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Node:
+    if output_node is None:
+        output_node = gm.graph.find_nodes(op="output")[0]
+
+    lm_head_node = output_node.all_input_nodes[0]
+    if is_op(lm_head_node, torch.ops.aten.to):
+        lm_head_node = lm_head_node.all_input_nodes[0]
+
+    return lm_head_node
+
+
+def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
+    gm, output_node = get_output_node(model)
+    lm_head_node = get_lm_head_node(gm, output_node)
+    return get_weight_tensor(gm, lm_head_node)
@@ -923,6 +923,12 @@ def shape(node: Node) -> Tuple[int, ...]:
     return node.meta["val"].shape
 
 
+def get_weight_tensor(gm: GraphModule, node: Node) -> "torch.Tensor":
+    """Extract the weight tensor from a node within a GraphModule."""
+    weight_name = extract_param_names_from_node(node)[0]
+    return gm.get_parameter(weight_name)
+
+
 def draw_graph(gm: GraphModule, filename: str):
     """
     Dump graphmodule to SVG file using PyTorch's built-in drawer.
@@ -19,39 +19,58 @@ import pytest
 from build_and_run_ad import ExperimentConfig, main
 from defs.conftest import llm_models_root
 
-from tensorrt_llm.llmapi import DraftTargetDecodingConfig, KvCacheConfig
+from tensorrt_llm import SamplingParams
+from tensorrt_llm._torch.auto_deploy.llm import LLM
+from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig
 
 prompts = [
     "What is the capital of France?",
     "Please explain the concept of gravity in simple words and a single sentence.",
     "What is the capital of Norway?",
     "What is the highest mountain in the world?",
 ]
 
+EAGLE_MODEL_SUBPATH = "EAGLE3-LLaMA3.1-Instruct-8B"
+LLAMA_BASE_SUBPATH = "llama-3.1-model/Llama-3.1-8B-Instruct"
+DRAFT_TARGET_MAX_DRAFT_LEN = 3
+EAGLE_MAX_DRAFT_LEN = 3
+
 
 def get_model_paths():
     """Get model paths using llm_models_root()."""
     models_root = llm_models_root()
-    base_model = os.path.join(
-        models_root,
-        "llama-3.1-model/Llama-3.1-8B-Instruct",
-    )
-    speculative_model = os.path.join(
+    base_model = os.path.join(models_root, LLAMA_BASE_SUBPATH)
+    draft_target_model = os.path.join(
         models_root,
         "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
     )
+    eagle_model = os.path.join(models_root, EAGLE_MODEL_SUBPATH)
 
     print(f"Base model path: {base_model}")
-    print(f"Speculative model path: {speculative_model}")
-    return base_model, speculative_model
+    print(f"DraftTarget draft model path: {draft_target_model}")
+    print(f"EAGLE model path: {eagle_model}")
+    return base_model, draft_target_model, eagle_model
 
 
-def run_with_autodeploy(model, speculative_model_dir, batch_size):
+def make_draft_target_config(spec_model_path: str):
+    return DraftTargetDecodingConfig(
+        max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model_dir=spec_model_path
+    )
+
+
+def make_eagle3_config(spec_model_path: str):
+    return EagleDecodingConfig(
+        max_draft_len=EAGLE_MAX_DRAFT_LEN,
+        speculative_model_dir=spec_model_path,
+        eagle3_one_model=False,
+        eagle3_layers_to_capture=None,
+    )
+
+
+def run_with_autodeploy(model, speculative_config, batch_size):
     """Run AutoDeploy with or without speculative decoding.
 
     Args:
         model: Path to the base model
-        speculative_model_dir: Path to the speculative model (None for baseline mode)
+        speculative_config: Speculative decoding config (None for baseline mode)
         batch_size: Number of prompts to process
 
     Returns:
@@ -60,16 +79,11 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
     # Select prompts based on batch size
     selected_prompts = prompts[:batch_size]
 
-    # Configure speculative decoding if speculative_model_dir is provided
-    spec_config = None
-    if speculative_model_dir is not None:
-        spec_config = DraftTargetDecodingConfig(
-            max_draft_len=3, speculative_model_dir=speculative_model_dir
-        )
+    spec_config = speculative_config
 
     # Configure KV cache
     kv_cache_config = KvCacheConfig(
-        free_gpu_memory_fraction=0.1,
+        free_gpu_memory_fraction=0.01,
     )
 
     # Configure AutoDeploy LLM arguments
@@ -81,9 +95,6 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
         "world_size": 1,
         "kv_cache_config": kv_cache_config,
         "disable_overlap_scheduler": True,
-        "transforms": {
-            "fuse_rmsnorm": {"rmsnorm_backend": "triton"},
-        },
        "max_num_tokens": 64,
     }
 
@@ -116,34 +127,46 @@
     return result["prompts_and_outputs"]
 
 
-@pytest.mark.parametrize("batch_size", [1, 4])
-def test_autodeploy_spec_dec(batch_size):
-    """Test AutoDeploy speculative decoding with different batch sizes.
+# Note: This test tests exact equality of outputs between speculative and baseline modes.
+# This can fail for larger batch sizes due to nondeterminism with in flight batching.
+# TODO: Figure out a robust test for output correctness that can pass for larger batch sizes.
+@pytest.mark.parametrize("spec_dec_mode", ["draft_target", "eagle3"])
+def test_autodeploy_spec_dec_output(spec_dec_mode):
+    """Test AutoDeploy speculative decoding output correctness.
 
     Runs with and without speculative decoding and verifies outputs are identical.
     """
     print("\n" + "=" * 80)
-    print(f"Testing AutoDeploy Speculative Decoding - Batch Size {batch_size}")
+    print(f"Testing AutoDeploy Speculative Decoding ({spec_dec_mode}) - Output Correctness")
    print("=" * 80)
 
-    base_model, speculative_model = get_model_paths()
+    base_model, draft_target_model, eagle_model = get_model_paths()
 
+    # Select model and config based on mode
+    if spec_dec_mode == "draft_target":
+        spec_model = draft_target_model
+        spec_config = make_draft_target_config(spec_model)
+    elif spec_dec_mode == "eagle3":  # eagle3
+        spec_model = eagle_model
+        spec_config = make_eagle3_config(spec_model)
+    else:
+        raise ValueError(f"Unsupported speculative decoding mode: {spec_dec_mode}")
+
     print(f"\nBase Model: {base_model}")
-    print(f"Speculative Model: {speculative_model}")
-    print(f"Batch Size: {batch_size}")
+    print(f"Speculative Model ({spec_dec_mode}): {spec_model}")
 
     # Run with speculative decoding
     print("\n[1/2] Running with speculative decoding enabled...")
     spec_outputs = run_with_autodeploy(
-        model=base_model, speculative_model_dir=speculative_model, batch_size=batch_size
+        model=base_model,
+        speculative_config=spec_config,
+        batch_size=1,
    )
     print(f"Generated {len(spec_outputs)} outputs with speculative decoding")
 
     # Run without speculative decoding (baseline)
     print("\n[2/2] Running without speculative decoding (baseline)...")
-    baseline_outputs = run_with_autodeploy(
-        model=base_model, speculative_model_dir=None, batch_size=batch_size
-    )
+    baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
     print(f"Generated {len(baseline_outputs)} outputs in baseline mode")
 
     # Verify outputs are identical
@@ -171,3 +194,85 @@
     print("\n" + "=" * 80)
     print("SUCCESS! All outputs are identical between spec-dec and baseline modes")
     print("=" * 80)
+
+
+def test_autodeploy_eagle3_acceptance_rate():
+    """Test Eagle3 acceptance rate with AutoDeploy engine.
+
+    Runs Eagle3 speculative decoding with streaming and verifies
+    that the acceptance rate is above a minimum threshold.
+    """
+    print("\n" + "=" * 80)
+    print("Testing AutoDeploy Eagle3 Acceptance Rate")
+    print("=" * 80)
+
+    base_model, _, eagle_model = get_model_paths()
+
+    print(f"\nBase Model: {base_model}")
+    print(f"Eagle3 Model: {eagle_model}")
+
+    max_draft_len = EAGLE_MAX_DRAFT_LEN
+
+    # Configure Eagle3 speculative decoding
+    speculative_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model,
+        eagle3_one_model=False,
+        eagle3_layers_to_capture=None,
+    )
+
+    # Configure KV cache
+    kv_cache_config = KvCacheConfig(
+        free_gpu_memory_fraction=0.01,
+    )
+
+    # Create AutoDeploy LLM with Eagle3 speculative decoding
+    # We directly instantiate the LLM class instead of using the main() function
+    # so that we can stream the outputs to see acceptance rates without needing to
+    # collect them in the executor.
+    llm = LLM(
+        model=base_model,
+        skip_loading_weights=False,
+        runtime="trtllm",
+        world_size=1,
+        kv_cache_config=kv_cache_config,
+        speculative_config=speculative_config,
+        disable_overlap_scheduler=True,
+        max_num_tokens=64,
+    )
+
+    # Tokenize 2 prompts to test multiple sequential requests
+    batch_tok_ids = [llm.tokenizer.encode(p) for p in prompts[:2]]
+
+    sampling_params = SamplingParams(max_tokens=128, temperature=0, seed=42)
+
+    print("\nRunning Eagle3 speculative decoding with streaming...")
+
+    # Process each request sequentially and verify acceptance rate
+    for i in range(len(batch_tok_ids)):
+        num_tokens = 0
+        num_drafted = 0
+        num_accepted = 0
+
+        for output in llm.generate_async(batch_tok_ids[i], sampling_params, streaming=True):
+            new_tokens = output.outputs[0].token_ids
+            num_drafted += max_draft_len
+            num_accepted += len(new_tokens) - num_tokens - 1
+            num_tokens = len(new_tokens)
+
+        accept_rate = num_accepted / num_drafted
+
+        print(f"\nRequest {i + 1} Acceptance Rate Statistics:")
+        print(f"  Total tokens drafted: {num_drafted}")
+        print(f"  Total tokens accepted: {num_accepted}")
+        print(f"  Acceptance rate: {accept_rate:.2%}")
+
+        # Verify acceptance rate is above minimum threshold (10%)
+        min_acceptance_rate = 0.10
+        assert accept_rate > min_acceptance_rate, (
+            f"Request {i + 1}: Acceptance rate {accept_rate:.2%} is below minimum threshold {min_acceptance_rate:.0%}"
+        )
+
+    print("\n" + "=" * 80)
+    print("SUCCESS! All requests passed acceptance rate threshold")
+    print("=" * 80)
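For reference, the acceptance-rate bookkeeping in the streaming loop of the test above works out as follows (numbers are illustrative, not measured): with max_draft_len = 3, every streamed step drafts 3 tokens and always commits 1 target token plus however many drafts were accepted, so a step that grows the output by 3 tokens counts as 2 accepted out of 3 drafted. Summed over a request, num_accepted equals total tokens minus the number of steps; for example, 40 steps producing 128 tokens give num_accepted = 128 - 40 = 88 out of num_drafted = 40 * 3 = 120, an acceptance rate of about 73%, comfortably above the 10% threshold the test asserts.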
@@ -116,8 +116,9 @@ l0_h100:
   - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
-  - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec[1]
-  - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec[4]
+  - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
+  - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
+  - examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate
   - condition:
       ranges:
         system_gpu_count: