[None][refactor] Refactor Torch Compile Backend, MoeLoadBalancer and warmup Logic (#6615)

Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Yi Zhang 2025-08-19 09:58:44 +08:00 committed by GitHub
parent 71e28eab36
commit a15af879ec
14 changed files with 270 additions and 184 deletions

View File

@ -37,7 +37,7 @@ class Backend:
enable_inductor=True,
enable_userbuffers=False,
enable_piecewise_cuda_graph: bool = False,
cuda_graph_batch_sizes: Optional[List[int]] = None,
capture_num_tokens: Optional[List[int]] = None,
max_num_streams: int = 1,
) -> None:
super().__init__()
@ -48,14 +48,12 @@ class Backend:
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.cuda_graph_batch_sizes = (cuda_graph_batch_sizes
if cuda_graph_batch_sizes is not None
else [])
self.capture_num_tokens = capture_num_tokens or []
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False
# We only need to create aux streams.
self.aux_streams = Backend.Streams(
[torch.cuda.Stream() for i in range(max_num_streams - 1)])
[torch.cuda.Stream() for _ in range(max_num_streams - 1)])
self.events = Backend.Events()
inductor_config.enable_auto_functionalized_v2 = False
@ -125,7 +123,7 @@ class Backend:
example_inputs,
self.enable_inductor,
self.input_num_tokens,
self.cuda_graph_batch_sizes,
self.capture_num_tokens,
self._graph_pool_handle,
len(self.aux_streams) + 1,
)

View File

@ -14,8 +14,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
make_weak_ref)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
is_call_function)
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
class PiecewiseInterpreter(Interpreter):
@ -25,7 +24,7 @@ class PiecewiseInterpreter(Interpreter):
module: GraphModule,
enable_inductor: bool,
compile_time_num_tokens: Union[int | torch.SymInt],
cuda_graph_batch_sizes: list[int],
capture_num_tokens: list[int],
exclude_modules_id: list[int],
graph_pool_handle: tuple[int, int],
garbage_collect_values: bool = True,
@ -37,7 +36,7 @@ class PiecewiseInterpreter(Interpreter):
self.fake_mode = detect_fake_mode()
self.compile_time_num_tokens = compile_time_num_tokens
self.cuda_graph_batch_sizes = cuda_graph_batch_sizes
self.capture_num_tokens = capture_num_tokens
self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
self.graph_pool_handle = graph_pool_handle
self.enable_inductor = enable_inductor
@ -86,7 +85,7 @@ class PiecewiseInterpreter(Interpreter):
target,
self.compile_time_num_tokens,
runtime_num_tokens_idx,
self.cuda_graph_batch_sizes,
self.capture_num_tokens,
self.graph_pool_handle,
compile_fx(submod, args) if self.enable_inductor else submod,
self.enable_inductor,
@ -120,7 +119,7 @@ class PiecewiseRunner(object):
name: str,
compile_time_num_tokens: Union[int | torch.SymInt],
runtime_num_tokens_idx: tuple[int],
cuda_graph_batch_sizes: List[int],
capture_num_tokens: List[int],
graph_pool_handle,
default_callable: Callable,
enable_inductor: bool,
@ -139,9 +138,9 @@ class PiecewiseRunner(object):
self.entries: dict[int, Entry] = {}
for bs in cuda_graph_batch_sizes:
self.entries[bs] = Entry(
bs,
for num_tokens in capture_num_tokens:
self.entries[num_tokens] = Entry(
num_tokens,
enable_inductor=self.enable_inductor,
callable=default_callable,
)
@ -167,7 +166,7 @@ class PiecewiseRunner(object):
if entry.cuda_graph is None:
if not get_enable_piecewise_cuda_graph_capture_flag():
if not get_capture_piecewise_cuda_graph_flag():
return entry.callable(*args)
if entry.warmup_count < 3:
@ -228,7 +227,7 @@ def piecewise_optimizer(
example_inputs: List[torch.Tensor],
enable_inductor: bool,
input_num_tokens: Union[int | torch.SymInt],
cuda_graph_batch_sizes: Sequence[int],
capture_num_tokens: Sequence[int],
graph_pool_handle: tuple[int, int],
max_num_streams: int = 1,
) -> tuple[GraphModule, int]:
@ -269,7 +268,7 @@ def piecewise_optimizer(
gm,
enable_inductor,
input_num_tokens,
cuda_graph_batch_sizes,
capture_num_tokens,
exclude_modules_id,
graph_pool_handle,
max_num_streams=max_num_streams,

View File

@ -1,3 +1,4 @@
import contextlib
from typing import Callable, List, Union
import torch
@ -33,16 +34,26 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]):
_enable_piecewise_cuda_graph_capture = False
def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
def set_capture_piecewise_cuda_graph_flag(enable: bool):
global _enable_piecewise_cuda_graph_capture
_enable_piecewise_cuda_graph_capture = enable
def get_enable_piecewise_cuda_graph_capture_flag() -> bool:
def get_capture_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph_capture
return _enable_piecewise_cuda_graph_capture
@contextlib.contextmanager
def capture_piecewise_cuda_graph(enable: bool):
prev_enable = get_capture_piecewise_cuda_graph_flag()
set_capture_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_capture_piecewise_cuda_graph_flag(prev_enable)
def inplace_info():
inplace_map = {
torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
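
The hunk above replaces direct flag toggling with a context manager that saves and restores the previous capture flag. A minimal usage sketch, assuming the module path tensorrt_llm._torch.compilation.utils; forward_fn and inputs are placeholders, not part of this commit:

from tensorrt_llm._torch.compilation.utils import (
    capture_piecewise_cuda_graph, get_capture_piecewise_cuda_graph_flag)

def warmup_once(forward_fn, inputs):
    # Enable piecewise CUDA graph capture only for this warmup pass; the
    # previous flag value is restored on exit, even if forward_fn raises.
    with capture_piecewise_cuda_graph(True):
        assert get_capture_piecewise_cuda_graph_flag()
        return forward_fn(inputs)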

View File

@ -8,6 +8,7 @@ from tensorrt_llm._utils import get_sm_version
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
from ..modules.multi_stream_utils import do_multi_stream
from ..utils import (fp4_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
@ -925,6 +926,8 @@ def get_stream(stream_id: int):
@torch.library.custom_op("trtllm::set_stream", mutates_args=())
def set_stream(stream_id: int) -> None:
if not do_multi_stream():
return
stream = get_stream(stream_id)
assert stream is not None
torch.cuda.set_stream(stream)
@ -932,18 +935,24 @@ def set_stream(stream_id: int) -> None:
@torch.library.custom_op("trtllm::record_event", mutates_args=())
def record_event(event_idx: int) -> None:
if not do_multi_stream():
return
event = get_event(event_idx)
event.record()
@torch.library.custom_op("trtllm::wait_event", mutates_args=())
def wait_event(event_idx: int) -> None:
if not do_multi_stream():
return
event = get_event(event_idx)
event.wait()
@torch.library.custom_op("trtllm::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
if not do_multi_stream():
return
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)

View File

@ -9,12 +9,12 @@ from mpi4py import MPI
import tensorrt_llm
import tensorrt_llm.bindings.internal.runtime as _tbr
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ...distributed import AllReduce
from ...utils import EventType
from ..multi_stream_utils import do_multi_stream
def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:
@ -472,7 +472,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["start_wait_gpu_stage"] == 0
self.func_called_count["start_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@ -491,7 +491,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["done_wait_gpu_stage"] == 0
self.func_called_count["done_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.MoeBalancer].wait()
def start_set_cpu_stage(self):
@ -502,7 +502,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["start_set_cpu_stage"] == 0
self.func_called_count["start_set_cpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@ -522,7 +522,7 @@ class SingleLayerMoeLoadBalancer:
self.func_called_count[name] = 0
self.statistic_flag_tensor = None
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.MoeBalancer].wait()
def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,
@ -544,7 +544,7 @@ class SingleLayerMoeLoadBalancer:
(self.expert_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@ -569,7 +569,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["update_local_statistic"] > 0
self.func_called_count["get_local_statistic_tensor"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeBalancer].record()
self.event_dict[EventType.MoeBalancer].wait()
@ -598,7 +598,7 @@ class SingleLayerMoeLoadBalancer:
self.single_layer_load_balancer_ptr)
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@ -636,7 +636,7 @@ class SingleLayerMoeLoadBalancer:
if self.updates_enabled:
self.update_local_statistic(local_raw_expert_ids, is_first_stage,
is_last_stage)
if is_graph_capturing():
if do_multi_stream():
with torch.cuda.stream(self.aux_stream):
_update_statistic()
else:
@ -660,7 +660,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["update_statistic_with_local_ids"] == 0
self.func_called_count["update_statistic_with_global_ids"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@ -851,8 +851,8 @@ class MoeLoadBalancer:
"""
self.load_balancer_impl.set_warm_up_iter_count(iter_count)
def set_next_iter_info(self, enable_statistic: Optional[bool],
enable_update_weights: Optional[bool]):
def set_iter_info(self, enable_statistic: Optional[bool],
enable_update_weights: Optional[bool]):
if enable_statistic is not None:
self.enable_statistic = enable_statistic
if enable_update_weights is not None:
@ -998,8 +998,8 @@ class MoeLoadBalancerIterContext:
"""
if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
):
self.moe_load_balancer.set_next_iter_info(self.enable_statistic,
self.enable_updates)
self.moe_load_balancer.set_iter_info(self.enable_statistic,
self.enable_updates)
self.moe_load_balancer.start_iter()
return self

View File

@ -1,8 +1,35 @@
import threading
from contextlib import contextmanager
from typing import Any, Callable, Optional
import torch
from ..pyexecutor.cuda_graph_runner import is_graph_capturing
class do_multi_stream_local(threading.local):
def __init__(self):
self.do_multi_stream = False
_local = do_multi_stream_local()
def set_do_multi_stream(enable: bool):
_local.do_multi_stream = enable
def do_multi_stream() -> bool:
return _local.do_multi_stream
@contextmanager
def with_multi_stream(enable: bool):
prev_do_multi_stream = _local.do_multi_stream
set_do_multi_stream(enable)
try:
yield
finally:
set_do_multi_stream(prev_do_multi_stream)
def maybe_execute_in_parallel(
@ -30,9 +57,9 @@ def maybe_execute_in_parallel(
tuple[Any, Any]: the return values of fn0() and fn1()
"""
do_multi_stream = is_graph_capturing() and aux_stream is not None
multi_stream = do_multi_stream() and aux_stream is not None
if do_multi_stream:
if multi_stream:
event0.record()
result0 = fn0()
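
maybe_execute_in_parallel now gates multi-stream execution on the new thread-local flag rather than on is_graph_capturing. A minimal sketch of the flag lifecycle, assuming the module path tensorrt_llm._torch.modules.multi_stream_utils:

from tensorrt_llm._torch.modules.multi_stream_utils import (
    do_multi_stream, with_multi_stream)

assert do_multi_stream() is False      # default: single-stream execution
with with_multi_stream(True):
    assert do_multi_stream() is True   # aux-stream scheduling allowed here
assert do_multi_stream() is False      # previous value restored on exit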

View File

@ -242,8 +242,8 @@ class KvCacheCreator:
torch_used_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
finally:
py_executor.shutdown()
py_executor.is_warmup = False
py_executor.shutdown()
py_executor.enable_iter_perf_stats = origin_iter_stats
py_executor.set_gather_responses(False)

View File

@ -79,6 +79,7 @@ class PyTorchConfig:
torch_compile_fullgraph: bool = True
torch_compile_inductor_enabled: bool = False
torch_compile_piecewise_cuda_graph: bool = False
torch_compile_piecewise_cuda_graph_num_tokens: Optional[List[int]] = None
# When torch compile is enabled, userbuffers is enabled by default
torch_compile_enable_userbuffers: bool = True
torch_compile_max_num_streams: int = 1
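
For reference, a hedged sketch of constructing PyTorchConfig with the new field; the import path and example values are assumptions, only the field names come from the dataclass above:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig  # path assumed

cfg = PyTorchConfig(
    torch_compile_piecewise_cuda_graph=True,
    # Token counts to capture piecewise CUDA graphs for; when left as None,
    # the model engine falls back to cuda_graph_batch_sizes.
    torch_compile_piecewise_cuda_graph_num_tokens=[128, 512, 2048],
    torch_compile_max_num_streams=2,
)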

View File

@ -1,28 +1,11 @@
import threading
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..attention_backend.interface import AttentionMetadata
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, set_piecewise_cuda_graph_flag
class graph_capturing_local(threading.local):
def __init__(self):
self.is_graph_capturing = False
_local = graph_capturing_local()
def set_graph_capturing(enable: bool):
_local.is_graph_capturing = enable
def is_graph_capturing() -> bool:
return _local.is_graph_capturing
from ..utils import make_weak_ref, piecewise_cuda_graph
class DecodingCUDAGraphRunner:
@ -97,14 +80,11 @@ class DecodingCUDAGraphRunner:
# internal states according to the docs:
# https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
# This also lets us initialize states in the attn_metadata.
set_graph_capturing(True)
set_piecewise_cuda_graph_flag(False)
for _ in range(2):
forward_fn(inputs)
with torch.cuda.graph(self._graph, pool=pool):
output = forward_fn(inputs)
set_graph_capturing(False)
set_piecewise_cuda_graph_flag(True)
with with_multi_stream(True), piecewise_cuda_graph(False):
for _ in range(2):
forward_fn(inputs)
with torch.cuda.graph(self._graph, pool=pool):
output = forward_fn(inputs)
# Mark weak ref here. The output tensor should be freed properly.
self._output = make_weak_ref(output)
return self._graph.pool()

View File

@ -40,7 +40,7 @@ from ..attention_backend.utils import get_attention_backend
from ..attention_backend.vanilla import VanillaAttentionMetadata
from ..autotuner import AutoTuner, autotune
from ..compilation.backend import Backend
from ..compilation.utils import set_enable_piecewise_cuda_graph_capture_flag
from ..compilation.utils import capture_piecewise_cuda_graph
from ..distributed import MPIDist
from ..distributed.communicator import init_pp_comm
from ..expert_statistic import ExpertStatistic
@ -293,8 +293,6 @@ class PyTorchModelEngine(ModelEngine):
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model
self.in_warmup = False
self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)
@ -335,6 +333,15 @@ class PyTorchModelEngine(ModelEngine):
pytorch_backend_config.torch_compile_piecewise_cuda_graph
and not self.enable_attention_dp)
piecewise_cuda_graph_num_tokens = (
pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens
or pytorch_backend_config.cuda_graph_batch_sizes or [])
self._piecewise_cuda_graph_num_tokens = [
i for i in piecewise_cuda_graph_num_tokens
if i <= self.max_num_tokens
]
try:
use_ub_for_nccl = (
pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
@ -349,8 +356,7 @@ class PyTorchModelEngine(ModelEngine):
enable_userbuffers=use_ub,
enable_piecewise_cuda_graph=self.
_torch_compile_piecewise_cuda_graph,
cuda_graph_batch_sizes=pytorch_backend_config.
cuda_graph_batch_sizes,
capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
max_num_streams=pytorch_backend_config.
torch_compile_max_num_streams)
if isinstance(self.model, DecoderModelForCausalLM):
@ -373,6 +379,8 @@ class PyTorchModelEngine(ModelEngine):
traceback.print_exception(Exception, e, e.__traceback__)
raise e
self.is_warmup = False
self.attn_backend = get_attention_backend(attn_backend)
if self.is_spec_decode:
@ -478,17 +486,44 @@ class PyTorchModelEngine(ModelEngine):
logger.debug(f"Detected use_mrope: {use_mrope}")
return use_mrope
@property
def is_warmup(self):
return getattr(self, "_is_warmup", False)
@is_warmup.setter
def is_warmup(self, value: bool):
self._is_warmup = value
self.moe_load_balancer_iter_info = (not value, not value)
@property
def moe_load_balancer_iter_info(self):
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)
if moe_load_balancer is not None:
return moe_load_balancer.enable_statistic, moe_load_balancer.enable_update_weights
return False, False
@moe_load_balancer_iter_info.setter
def moe_load_balancer_iter_info(self, value: Tuple[bool, bool]):
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)
if moe_load_balancer is not None:
moe_load_balancer.set_iter_info(enable_statistic=value[0],
enable_update_weights=value[1])
@property
def use_beam_search(self):
return self.max_beam_width > 1
@contextmanager
def set_warmup_flag(self):
self.in_warmup = True
prev_is_warmup = self.is_warmup
self.is_warmup = True
try:
yield
finally:
self.in_warmup = False
self.is_warmup = prev_is_warmup
@staticmethod
def with_warmup_flag(method):
@ -669,120 +704,110 @@ class PyTorchModelEngine(ModelEngine):
if cp_type == CpType.STAR:
return
with contextlib.ExitStack() as stack:
if self._torch_compile_enabled:
if self._torch_compile_enabled:
def disable_optimization(backend: Backend):
# Disable torch.compile optimization and fallback to eager execution
backend.bypass_optimization()
# Disable piecewise CUDA graph capture since the capture run will produce wrong results
set_enable_piecewise_cuda_graph_capture_flag(False)
stack.callback(disable_optimization,
self._torch_compile_backend)
self._torch_compile_backend.enable_optimization()
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
for bs in warmup_batch_size:
for num_tokens_per_request in [
1,
min(self.max_num_tokens // max(bs, 1),
min(available_tokens, self.max_seq_len - 1))
]:
with release_batch(
get_torch_compile_warmup_request(
bs, num_tokens_per_request)) as batch:
if batch is None:
# No KV cache space!
continue
logger.info(
f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
if self.pytorch_backend_config.enable_autotuner:
with self.no_cuda_graph(), autotune():
result = get_autotune_warmup_request()
with release_batch(result) as batch:
if batch is None:
# No KV cache space!
pass
else:
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
for bs in warmup_batch_size:
for num_tokens_per_request in [
1,
min(self.max_num_tokens // max(bs, 1),
min(available_tokens, self.max_seq_len - 1))
]:
with release_batch(
get_torch_compile_warmup_request(
bs, num_tokens_per_request)) as batch:
if batch is None:
# No KV cache space!
continue
logger.info(
f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
logger.info(
f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
)
AutoTuner.get().print_profiling_cache()
if not (self._run_cuda_graphs
or self._torch_compile_piecewise_cuda_graph):
return
logger.info(
f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes."
)
# Reverse the order of the cuda graph batch sizes so that smaller batch size graphs can reuse the memory of larger batch size graphs
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = [self.max_draft_len]
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (not self.is_draft_model and self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)
for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
# skip batch size larger than self.batch_size
continue
for draft_len in draft_lengths:
with release_batch(
get_cuda_graph_warmup_request(bs,
draft_len)) as batch:
if batch is None:
# No KV cache space!
return
logger.info(
f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
if self.pytorch_backend_config.enable_autotuner:
with self.no_cuda_graph(), autotune():
result = get_autotune_warmup_request()
with release_batch(result) as batch:
if batch is None:
# No KV cache space!
pass
else:
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
for seq_lens in cuda_graph_batch_sizes:
set_enable_piecewise_cuda_graph_capture_flag(True)
logger.info(
f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
)
AutoTuner.get().print_profiling_cache()
if not (self._run_cuda_graphs
or self._torch_compile_piecewise_cuda_graph):
return
logger.info(
f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes."
)
# Reverse the order of the cuda graph batch sizes so that smaller batch size graphs can reuse the memory of larger batch size graphs
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = [self.max_draft_len]
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (not self.is_draft_model and self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)
for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
# skip batch size larger than self.batch_size
continue
for draft_len in draft_lengths:
with release_batch(get_cuda_graph_warmup_request(
bs, draft_len)) as batch:
if batch is None:
# No KV cache space!
return
logger.info(
f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
piecewise_cuda_graph_num_tokens = sorted(
self._piecewise_cuda_graph_num_tokens, reverse=True)
with capture_piecewise_cuda_graph(True):
for num_tokens in piecewise_cuda_graph_num_tokens:
with self.no_cuda_graph():
with release_batch(
get_torch_compile_warmup_request(
1, seq_lens)) as batch:
1, num_tokens)) as batch:
logger.info(
f"Run piecewise CUDA graph warmup for seq_lens={seq_lens}"
f"Run piecewise CUDA graph warmup for num tokens={num_tokens}"
)
# self.model.mtp_worker.stored_input_ids = []
for _ in range(3):
self.forward(batch,
new_tensors_device=None,
@ -793,7 +818,6 @@ class PyTorchModelEngine(ModelEngine):
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
set_enable_piecewise_cuda_graph_capture_flag(False)
# Set the value back to the original value
self.enable_spec_decode = self.is_spec_decode
@ -1541,7 +1565,7 @@ class PyTorchModelEngine(ModelEngine):
# Cache indirection is only used for beam search on generation requests
if self.use_beam_search and num_generation_requests > 0:
# CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph
if cache_indirection_buffer is not None:
#Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
self.cache_indirection_attention[:num_generation_requests].copy_(
@ -2151,14 +2175,8 @@ class PyTorchModelEngine(ModelEngine):
spec_resource_manager = None
spec_metadata = None
moe_load_balancer = None
if hasattr(self, 'moe_load_balancer'):
moe_load_balancer = getattr(self, 'moe_load_balancer')
if not self.in_warmup:
moe_enable_statistic = True
moe_enable_update = True
moe_load_balancer.set_next_iter_info(moe_enable_statistic,
moe_enable_update)
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)
if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(

View File

@ -161,7 +161,6 @@ class PyExecutor:
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
PROFILE_START_STOP_ENV_VAR_NAME)
self.gc_nvtx_watcher_handle = _gc_nvtx_watcher()
self.is_warmup = False # During warmup, we don't enable the profiler
# related modules
self.resource_manager = resource_manager
@ -220,9 +219,12 @@ class PyExecutor:
self.inflight_req_ids = ReqIdsSet()
# During warmup, we don't enable the profiler
self.is_warmup = True
self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
self.draft_model_engine.warmup(self.resource_manager)
self.is_warmup = False
self.is_shutdown = False
self.max_batch_size = max_batch_size
@ -280,6 +282,18 @@ class PyExecutor:
finally:
self._executor_loop_cleanup()
@property
def is_warmup(self) -> bool:
return getattr(self, "_is_warmup", False)
@is_warmup.setter
def is_warmup(self, value: bool):
self._is_warmup = value
# Propagate the warmup flag to the model engine so torch compile warmup is triggered and MoE load balancer statistics updates are skipped
self.model_engine.is_warmup = value
if self.draft_model_engine is not None:
self.draft_model_engine.is_warmup = value
def start_worker(self):
with self.worker_lock:
if self.worker_started == False:
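
Taken together with the model engine hunks above, setting the executor-level flag now cascades through properties instead of ad-hoc bookkeeping; a comment-only sketch of the flow, using names from this commit:

# py_executor.is_warmup = True
#   -> model_engine.is_warmup = True              (and draft_model_engine, if present)
#     -> model_engine.moe_load_balancer_iter_info = (False, False)
#       -> moe_load_balancer.set_iter_info(enable_statistic=False,
#                                          enable_update_weights=False)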

View File

@ -265,3 +265,13 @@ def set_piecewise_cuda_graph_flag(enable: bool):
def get_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph
return _enable_piecewise_cuda_graph
@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
prev_enable = get_piecewise_cuda_graph_flag()
set_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_piecewise_cuda_graph_flag(prev_enable)

View File

@ -1990,6 +1990,21 @@ class TorchCompileConfig(StrictBaseModel):
default=False,
description="Enable piecewise CUDA graph in torch.compile.")
capture_num_tokens: Optional[List[int]] = Field(
default=None,
description=
"List of num of tokens to capture the piecewise CUDA graph for. If not provided, the number of tokens will be the same as cuda_graph_config.batch_sizes."
)
@field_validator('capture_num_tokens')
@classmethod
def validate_capture_num_tokens(cls, v):
if v is None:
return v
if any(t <= 0 for t in v):
raise ValueError("capture_num_tokens must contain positive ints.")
return sorted(set(v), reverse=True)
enable_userbuffers: bool = Field(
default=True,
description=
@ -2368,6 +2383,10 @@ class TorchLlmArgs(BaseLlmArgs):
enable_piecewise_cuda_graph
if self.torch_compile_config is not None else TorchCompileConfig.
model_fields['enable_piecewise_cuda_graph'].default,
torch_compile_piecewise_cuda_graph_num_tokens=self.
torch_compile_config.capture_num_tokens
if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['capture_num_tokens'].default,
torch_compile_enable_userbuffers=self.torch_compile_config.
enable_userbuffers if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_userbuffers'].default,
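
A hedged usage sketch of the new capture_num_tokens knob; the import path is assumed, and the de-duplicate-and-sort behavior follows the validator above:

from tensorrt_llm.llmapi.llm_args import TorchCompileConfig  # path assumed

compile_cfg = TorchCompileConfig(
    enable_piecewise_cuda_graph=True,
    capture_num_tokens=[64, 256, 1024, 1024],
)
# The validator rejects non-positive values and returns a de-duplicated,
# descending list.
assert compile_cfg.capture_num_tokens == [1024, 256, 64]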

View File

@ -269,7 +269,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with(
10)
balancer.set_next_iter_info(True, True)
balancer.set_iter_info(True, True)
with MoeLoadBalancerIterContext(balancer):
mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
@ -308,7 +308,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
balancer.finalize_model()
# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
balancer.set_iter_info(True, False)
# Create sample token data - each token selects 2 experts
# 4 tokens, each selecting 2 experts
@ -373,7 +373,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
balancer.finalize_model()
# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
balancer.set_iter_info(True, False)
# Create sample token data - tokens selecting different experts
token_selected_experts = torch.tensor(