mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][refactor] Refactor Torch Compile Backend, MoeLoadBalancer and warmup Logic (#6615)

Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>

This commit is contained in:
parent 71e28eab36
commit a15af879ec
@@ -37,7 +37,7 @@ class Backend:
enable_inductor=True,
enable_userbuffers=False,
enable_piecewise_cuda_graph: bool = False,
cuda_graph_batch_sizes: Optional[List[int]] = None,
capture_num_tokens: Optional[List[int]] = None,
max_num_streams: int = 1,
) -> None:
super().__init__()

@@ -48,14 +48,12 @@ class Backend:
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.cuda_graph_batch_sizes = (cuda_graph_batch_sizes
if cuda_graph_batch_sizes is not None
else [])
self.capture_num_tokens = capture_num_tokens or []
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False
# We only need to create aux streams.
self.aux_streams = Backend.Streams(
[torch.cuda.Stream() for i in range(max_num_streams - 1)])
[torch.cuda.Stream() for _ in range(max_num_streams - 1)])
self.events = Backend.Events()
inductor_config.enable_auto_functionalized_v2 = False

@@ -125,7 +123,7 @@ class Backend:
example_inputs,
self.enable_inductor,
self.input_num_tokens,
self.cuda_graph_batch_sizes,
self.capture_num_tokens,
self._graph_pool_handle,
len(self.aux_streams) + 1,
)
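For orientation, a minimal sketch (not part of the commit) of how a backend with this constructor signature is typically handed to torch.compile; the import path, the toy model, and the standalone usage are assumptions, only the argument names come from the hunk above.

import torch
import torch.nn as nn

# Path inferred from "from ..compilation.backend import Backend" later in this diff.
from tensorrt_llm._torch.compilation.backend import Backend

backend = Backend(
    enable_inductor=False,
    enable_userbuffers=False,
    enable_piecewise_cuda_graph=True,
    cuda_graph_batch_sizes=[1, 2, 4],
    capture_num_tokens=[64, 256],  # new argument introduced by this commit
    max_num_streams=2,             # creates max_num_streams - 1 auxiliary CUDA streams
)

model = nn.Linear(16, 16).cuda()
# A Dynamo backend is a callable taking (GraphModule, example_inputs);
# Backend is assumed to be usable that way, as the model engine wiring below suggests.
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(4, 16, device="cuda"))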
@@ -14,8 +14,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
make_weak_ref)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
is_call_function)
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function


class PiecewiseInterpreter(Interpreter):

@@ -25,7 +24,7 @@ class PiecewiseInterpreter(Interpreter):
module: GraphModule,
enable_inductor: bool,
compile_time_num_tokens: Union[int | torch.SymInt],
cuda_graph_batch_sizes: list[int],
capture_num_tokens: list[int],
exclude_modules_id: list[int],
graph_pool_handle: tuple[int, int],
garbage_collect_values: bool = True,

@@ -37,7 +36,7 @@ class PiecewiseInterpreter(Interpreter):
self.fake_mode = detect_fake_mode()

self.compile_time_num_tokens = compile_time_num_tokens
self.cuda_graph_batch_sizes = cuda_graph_batch_sizes
self.capture_num_tokens = capture_num_tokens
self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
self.graph_pool_handle = graph_pool_handle
self.enable_inductor = enable_inductor

@@ -86,7 +85,7 @@ class PiecewiseInterpreter(Interpreter):
target,
self.compile_time_num_tokens,
runtime_num_tokens_idx,
self.cuda_graph_batch_sizes,
self.capture_num_tokens,
self.graph_pool_handle,
compile_fx(submod, args) if self.enable_inductor else submod,
self.enable_inductor,

@@ -120,7 +119,7 @@ class PiecewiseRunner(object):
name: str,
compile_time_num_tokens: Union[int | torch.SymInt],
runtime_num_tokens_idx: tuple[int],
cuda_graph_batch_sizes: List[int],
capture_num_tokens: List[int],
graph_pool_handle,
default_callable: Callable,
enable_inductor: bool,

@@ -139,9 +138,9 @@ class PiecewiseRunner(object):

self.entries: dict[int, Entry] = {}

for bs in cuda_graph_batch_sizes:
self.entries[bs] = Entry(
bs,
for num_tokens in capture_num_tokens:
self.entries[num_tokens] = Entry(
num_tokens,
enable_inductor=self.enable_inductor,
callable=default_callable,
)

@@ -167,7 +166,7 @@ class PiecewiseRunner(object):

if entry.cuda_graph is None:

if not get_enable_piecewise_cuda_graph_capture_flag():
if not get_capture_piecewise_cuda_graph_flag():
return entry.callable(*args)

if entry.warmup_count < 3:

@@ -228,7 +227,7 @@ def piecewise_optimizer(
example_inputs: List[torch.Tensor],
enable_inductor: bool,
input_num_tokens: Union[int | torch.SymInt],
cuda_graph_batch_sizes: Sequence[int],
capture_num_tokens: Sequence[int],
graph_pool_handle: tuple[int, int],
max_num_streams: int = 1,
) -> tuple[GraphModule, int]:

@@ -269,7 +268,7 @@ def piecewise_optimizer(
gm,
enable_inductor,
input_num_tokens,
cuda_graph_batch_sizes,
capture_num_tokens,
exclude_modules_id,
graph_pool_handle,
max_num_streams=max_num_streams,
@@ -1,3 +1,4 @@
import contextlib
from typing import Callable, List, Union

import torch

@@ -33,16 +34,26 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]):
_enable_piecewise_cuda_graph_capture = False


def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
def set_capture_piecewise_cuda_graph_flag(enable: bool):
global _enable_piecewise_cuda_graph_capture
_enable_piecewise_cuda_graph_capture = enable


def get_enable_piecewise_cuda_graph_capture_flag() -> bool:
def get_capture_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph_capture
return _enable_piecewise_cuda_graph_capture


@contextlib.contextmanager
def capture_piecewise_cuda_graph(enable: bool):
prev_enable = get_capture_piecewise_cuda_graph_flag()
set_capture_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_capture_piecewise_cuda_graph_flag(prev_enable)
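A usage sketch of the new context manager (not part of the commit; the absolute import path is inferred from the relative imports in this file): it flips the module-level capture flag for the duration of a block and restores the previous value on exit, even if the block raises.

# Assumed module path: tensorrt_llm._torch.compilation.utils
from tensorrt_llm._torch.compilation.utils import (
    capture_piecewise_cuda_graph, get_capture_piecewise_cuda_graph_flag)

assert get_capture_piecewise_cuda_graph_flag() is False  # default is off
with capture_piecewise_cuda_graph(True):
    # piecewise CUDA graph capture is allowed inside this block
    assert get_capture_piecewise_cuda_graph_flag() is True
assert get_capture_piecewise_cuda_graph_flag() is False  # previous value restored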

def inplace_info():
inplace_map = {
torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
@@ -8,6 +8,7 @@ from tensorrt_llm._utils import get_sm_version

from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
from ..modules.multi_stream_utils import do_multi_stream
from ..utils import (fp4_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)

@@ -925,6 +926,8 @@ def get_stream(stream_id: int):

@torch.library.custom_op("trtllm::set_stream", mutates_args=())
def set_stream(stream_id: int) -> None:
if not do_multi_stream():
return
stream = get_stream(stream_id)
assert stream is not None
torch.cuda.set_stream(stream)

@@ -932,18 +935,24 @@ def set_stream(stream_id: int) -> None:

@torch.library.custom_op("trtllm::record_event", mutates_args=())
def record_event(event_idx: int) -> None:
if not do_multi_stream():
return
event = get_event(event_idx)
event.record()


@torch.library.custom_op("trtllm::wait_event", mutates_args=())
def wait_event(event_idx: int) -> None:
if not do_multi_stream():
return
event = get_event(event_idx)
event.wait()


@torch.library.custom_op("trtllm::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
if not do_multi_stream():
return
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)
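These custom ops are thin, graph-capturable wrappers over standard CUDA stream and event primitives, gated on do_multi_stream(). The underlying pattern in plain PyTorch (a sketch, not code from the commit):

import torch

assert torch.cuda.is_available()  # example requires a GPU
main = torch.cuda.current_stream()
aux = torch.cuda.Stream()
event = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")
y = x @ x              # work enqueued on the main stream
event.record(main)     # what trtllm::record_event does for a registered event

with torch.cuda.stream(aux):
    event.wait()          # trtllm::wait_event: aux waits for main's work
    z = y.relu()          # safe: y is ready on aux after the wait
    y.record_stream(aux)  # trtllm::record_stream: keep y alive until aux finishes
torch.cuda.synchronize()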
@@ -9,12 +9,12 @@ from mpi4py import MPI

import tensorrt_llm
import tensorrt_llm.bindings.internal.runtime as _tbr
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ...distributed import AllReduce
from ...utils import EventType
from ..multi_stream_utils import do_multi_stream


def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight:

@@ -472,7 +472,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["start_wait_gpu_stage"] == 0
self.func_called_count["start_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

@@ -491,7 +491,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["done_wait_gpu_stage"] == 0
self.func_called_count["done_wait_gpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.MoeBalancer].wait()

def start_set_cpu_stage(self):

@@ -502,7 +502,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["start_set_cpu_stage"] == 0
self.func_called_count["start_set_cpu_stage"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

@@ -522,7 +522,7 @@ class SingleLayerMoeLoadBalancer:
self.func_called_count[name] = 0
self.statistic_flag_tensor = None
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.MoeBalancer].wait()

def update_local_statistic(self, local_raw_expert_ids: torch.Tensor,

@@ -544,7 +544,7 @@ class SingleLayerMoeLoadBalancer:
(self.expert_count, ),
dtype=torch.int32,
device=torch.device('cuda'))
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

@@ -569,7 +569,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["update_local_statistic"] > 0
self.func_called_count["get_local_statistic_tensor"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.MoeBalancer].record()
self.event_dict[EventType.MoeBalancer].wait()

@@ -598,7 +598,7 @@ class SingleLayerMoeLoadBalancer:
self.single_layer_load_balancer_ptr)

if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

@@ -636,7 +636,7 @@ class SingleLayerMoeLoadBalancer:
if self.updates_enabled:
self.update_local_statistic(local_raw_expert_ids, is_first_stage,
is_last_stage)
if is_graph_capturing():
if do_multi_stream():
with torch.cuda.stream(self.aux_stream):
_update_statistic()
else:

@@ -660,7 +660,7 @@ class SingleLayerMoeLoadBalancer:
assert self.func_called_count["update_statistic_with_local_ids"] == 0
self.func_called_count["update_statistic_with_global_ids"] += 1
if self.updates_enabled:
if is_graph_capturing():
if do_multi_stream():
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()

@@ -851,8 +851,8 @@ class MoeLoadBalancer:
"""
self.load_balancer_impl.set_warm_up_iter_count(iter_count)

def set_next_iter_info(self, enable_statistic: Optional[bool],
enable_update_weights: Optional[bool]):
def set_iter_info(self, enable_statistic: Optional[bool],
enable_update_weights: Optional[bool]):
if enable_statistic is not None:
self.enable_statistic = enable_statistic
if enable_update_weights is not None:

@@ -998,8 +998,8 @@ class MoeLoadBalancerIterContext:
"""
if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing(
):
self.moe_load_balancer.set_next_iter_info(self.enable_statistic,
self.enable_updates)
self.moe_load_balancer.set_iter_info(self.enable_statistic,
self.enable_updates)
self.moe_load_balancer.start_iter()
return self
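A hedged sketch of the renamed API in use, mirroring the unit tests further down in this diff (balancer construction is elided, so this is illustrative rather than runnable):

# `balancer` is an already-constructed MoeLoadBalancer; construction is elided here.
balancer.set_iter_info(enable_statistic=True, enable_update_weights=False)  # statistics only
with MoeLoadBalancerIterContext(balancer):
    # One inference iteration runs inside the context; start_iter() is issued on enter.
    ...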
@@ -1,8 +1,35 @@
import threading
from contextlib import contextmanager
from typing import Any, Callable, Optional

import torch

from ..pyexecutor.cuda_graph_runner import is_graph_capturing


class do_multi_stream_local(threading.local):

def __init__(self):
self.do_multi_stream = False


_local = do_multi_stream_local()


def set_do_multi_stream(enable: bool):
_local.do_multi_stream = enable


def do_multi_stream() -> bool:
return _local.do_multi_stream


@contextmanager
def with_multi_stream(enable: bool):
prev_do_multi_stream = _local.do_multi_stream
set_do_multi_stream(enable)
try:
yield
finally:
set_do_multi_stream(prev_do_multi_stream)


def maybe_execute_in_parallel(

@@ -30,9 +57,9 @@ def maybe_execute_in_parallel(
tuple[Any, Any]: the return values of fn0() and fn1()
"""

do_multi_stream = is_graph_capturing() and aux_stream is not None
multi_stream = do_multi_stream() and aux_stream is not None

if do_multi_stream:
if multi_stream:
event0.record()
result0 = fn0()
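A hedged usage sketch (import path inferred from this diff; the signature of maybe_execute_in_parallel is assumed from its docstring above): multi-stream fan-out is now gated by the thread-local with_multi_stream block rather than by CUDA-graph capture state.

import torch
# Assumed path: tensorrt_llm._torch.modules.multi_stream_utils
from tensorrt_llm._torch.modules.multi_stream_utils import (
    maybe_execute_in_parallel, with_multi_stream)

aux_stream = torch.cuda.Stream()
event0, event1 = torch.cuda.Event(), torch.cuda.Event()
x = torch.randn(2048, 2048, device="cuda")

with with_multi_stream(True):  # outside this block, both fns run serially on the main stream
    # Signature assumed: (fn0, fn1, event0, event1, aux_stream) -> (fn0(), fn1())
    out0, out1 = maybe_execute_in_parallel(lambda: x @ x,
                                           lambda: x.relu(),
                                           event0, event1, aux_stream)
torch.cuda.synchronize()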
@@ -242,8 +242,8 @@ class KvCacheCreator:
torch_used_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
finally:
py_executor.shutdown()
py_executor.is_warmup = False
py_executor.shutdown()
py_executor.enable_iter_perf_stats = origin_iter_stats
py_executor.set_gather_responses(False)
@@ -79,6 +79,7 @@ class PyTorchConfig:
torch_compile_fullgraph: bool = True
torch_compile_inductor_enabled: bool = False
torch_compile_piecewise_cuda_graph: bool = False
torch_compile_piecewise_cuda_graph_num_tokens: Optional[List[int]] = None
# When torch compile is enabled, userbuffers is enabled by default
torch_compile_enable_userbuffers: bool = True
torch_compile_max_num_streams: int = 1
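For context, a sketch of how the new knobs in this internal config would be set; the construction style (dataclass-like keyword arguments) and import path are assumptions, only the field names come from the hunk above.

# Assumed path: tensorrt_llm._torch.pyexecutor.config
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

config = PyTorchConfig(
    torch_compile_piecewise_cuda_graph=True,
    torch_compile_piecewise_cuda_graph_num_tokens=[64, 256, 1024],  # new in this commit
    torch_compile_max_num_streams=2,
)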
@@ -1,28 +1,11 @@
import threading
from typing import Any, Callable, Dict, Optional, Tuple

import torch

from ..attention_backend.interface import AttentionMetadata
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, set_piecewise_cuda_graph_flag


class graph_capturing_local(threading.local):

def __init__(self):
self.is_graph_capturing = False


_local = graph_capturing_local()


def set_graph_capturing(enable: bool):
_local.is_graph_capturing = enable


def is_graph_capturing() -> bool:
return _local.is_graph_capturing
from ..utils import make_weak_ref, piecewise_cuda_graph


class DecodingCUDAGraphRunner:

@@ -97,14 +80,11 @@ class DecodingCUDAGraphRunner:
# internal states according to the docs:
# https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
# This also lets us initialize states in the attn_metadata.
set_graph_capturing(True)
set_piecewise_cuda_graph_flag(False)
for _ in range(2):
forward_fn(inputs)
with torch.cuda.graph(self._graph, pool=pool):
output = forward_fn(inputs)
set_graph_capturing(False)
set_piecewise_cuda_graph_flag(True)
with with_multi_stream(True), piecewise_cuda_graph(False):
for _ in range(2):
forward_fn(inputs)
with torch.cuda.graph(self._graph, pool=pool):
output = forward_fn(inputs)
# Mark weak ref here. The output tensor should be freed properly.
self._output = make_weak_ref(output)
return self._graph.pool()
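The capture sequence above follows the standard PyTorch CUDA Graphs recipe: warm up the callable a couple of times, then capture it into a torch.cuda.CUDAGraph and replay it later. A minimal standalone sketch of that recipe in plain PyTorch (not TensorRT-LLM code):

import torch

model = torch.nn.Linear(256, 256).cuda()
static_input = torch.randn(8, 256, device="cuda")

# Warm up on a side stream so capture sees initialized states.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static input buffer, then launch the whole graph at once.
static_input.copy_(torch.randn(8, 256, device="cuda"))
graph.replay()
print(static_output.shape)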
@@ -40,7 +40,7 @@ from ..attention_backend.utils import get_attention_backend
from ..attention_backend.vanilla import VanillaAttentionMetadata
from ..autotuner import AutoTuner, autotune
from ..compilation.backend import Backend
from ..compilation.utils import set_enable_piecewise_cuda_graph_capture_flag
from ..compilation.utils import capture_piecewise_cuda_graph
from ..distributed import MPIDist
from ..distributed.communicator import init_pp_comm
from ..expert_statistic import ExpertStatistic

@@ -293,8 +293,6 @@ class PyTorchModelEngine(ModelEngine):
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model

self.in_warmup = False

self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)

@@ -335,6 +333,15 @@ class PyTorchModelEngine(ModelEngine):
pytorch_backend_config.torch_compile_piecewise_cuda_graph
and not self.enable_attention_dp)

piecewise_cuda_graph_num_tokens = (
pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens
or pytorch_backend_config.cuda_graph_batch_sizes or [])

self._piecewise_cuda_graph_num_tokens = [
i for i in piecewise_cuda_graph_num_tokens
if i <= self.max_num_tokens
]
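The fallback-and-filter logic above is easy to check by hand. A self-contained sketch with made-up values (plain Python, not engine code):

# Stand-ins for the config fields used above.
torch_compile_piecewise_cuda_graph_num_tokens = None   # user did not set it
cuda_graph_batch_sizes = [1, 32, 2048, 8192]
max_num_tokens = 4096

num_tokens = (torch_compile_piecewise_cuda_graph_num_tokens
              or cuda_graph_batch_sizes or [])
piecewise_num_tokens = [i for i in num_tokens if i <= max_num_tokens]
assert piecewise_num_tokens == [1, 32, 2048]  # 8192 exceeds max_num_tokens and is dropped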
try:
use_ub_for_nccl = (
pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"

@@ -349,8 +356,7 @@ class PyTorchModelEngine(ModelEngine):
enable_userbuffers=use_ub,
enable_piecewise_cuda_graph=self.
_torch_compile_piecewise_cuda_graph,
cuda_graph_batch_sizes=pytorch_backend_config.
cuda_graph_batch_sizes,
capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
max_num_streams=pytorch_backend_config.
torch_compile_max_num_streams)
if isinstance(self.model, DecoderModelForCausalLM):

@@ -373,6 +379,8 @@ class PyTorchModelEngine(ModelEngine):
traceback.print_exception(Exception, e, e.__traceback__)
raise e

self.is_warmup = False

self.attn_backend = get_attention_backend(attn_backend)

if self.is_spec_decode:
@@ -478,17 +486,44 @@ class PyTorchModelEngine(ModelEngine):
logger.debug(f"Detected use_mrope: {use_mrope}")
return use_mrope

@property
def is_warmup(self):
return getattr(self, "_is_warmup", False)

@is_warmup.setter
def is_warmup(self, value: bool):
self._is_warmup = value

self.moe_load_balancer_iter_info = (not value, not value)

@property
def moe_load_balancer_iter_info(self):
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)
if moe_load_balancer is not None:
return moe_load_balancer.enable_statistic, moe_load_balancer.enable_update_weights
return False, False

@moe_load_balancer_iter_info.setter
def moe_load_balancer_iter_info(self, value: Tuple[bool, bool]):
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)
if moe_load_balancer is not None:
moe_load_balancer.set_iter_info(enable_statistic=value[0],
enable_update_weights=value[1])
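In effect, flipping the warmup flag also mutes the MoE load balancer for that phase. A hedged illustration against the properties above (engine construction elided, so this is not a runnable example):

# `engine` is an already-constructed PyTorchModelEngine; construction is elided here.
engine.is_warmup = True
# ...is equivalent to:
#   engine._is_warmup = True
#   engine.moe_load_balancer.set_iter_info(enable_statistic=False,
#                                          enable_update_weights=False)  # if a balancer exists

engine.is_warmup = False  # statistics and weight updates are enabled again
assert engine.moe_load_balancer_iter_info == (True, True)  # when a balancer is attached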
@property
def use_beam_search(self):
return self.max_beam_width > 1

@contextmanager
def set_warmup_flag(self):
self.in_warmup = True
prev_is_warmup = self.is_warmup
self.is_warmup = True
try:
yield
finally:
self.in_warmup = False
self.is_warmup = prev_is_warmup

@staticmethod
def with_warmup_flag(method):
@@ -669,120 +704,110 @@ class PyTorchModelEngine(ModelEngine):
if cp_type == CpType.STAR:
return

with contextlib.ExitStack() as stack:
if self._torch_compile_enabled:
if self._torch_compile_enabled:

def disable_optimization(backend: Backend):
# Disable torch.compile optimization and fallback to eager execution
backend.bypass_optimization()
# Disable piecewise CUDA graph capture since the capture run will produce wrong results
set_enable_piecewise_cuda_graph_capture_flag(False)

stack.callback(disable_optimization,
self._torch_compile_backend)

self._torch_compile_backend.enable_optimization()

# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
for bs in warmup_batch_size:
for num_tokens_per_request in [
1,
min(self.max_num_tokens // max(bs, 1),
min(available_tokens, self.max_seq_len - 1))
]:
with release_batch(
get_torch_compile_warmup_request(
bs, num_tokens_per_request)) as batch:
if batch is None:
# No KV cache space!
continue
logger.info(
f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()

if self.pytorch_backend_config.enable_autotuner:
with self.no_cuda_graph(), autotune():
result = get_autotune_warmup_request()
with release_batch(result) as batch:
if batch is None:
# No KV cache space!
pass
else:
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
for bs in warmup_batch_size:
for num_tokens_per_request in [
1,
min(self.max_num_tokens // max(bs, 1),
min(available_tokens, self.max_seq_len - 1))
]:
with release_batch(
get_torch_compile_warmup_request(
bs, num_tokens_per_request)) as batch:
if batch is None:
# No KV cache space!
continue
logger.info(
f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
)
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()

logger.info(
f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
)

AutoTuner.get().print_profiling_cache()

if not (self._run_cuda_graphs
or self._torch_compile_piecewise_cuda_graph):
return

logger.info(
f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes."
)
# Reverse the order of the cuda graph batch sizes to make smaller batch size graph could reuse larger batch size graph memory
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = [self.max_draft_len]
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (not self.is_draft_model and self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)

for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
# skip batch size larger than self.batch_size
continue

for draft_len in draft_lengths:
with release_batch(
get_cuda_graph_warmup_request(bs,
draft_len)) as batch:
if batch is None:
# No KV cache space!
return
logger.info(
f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
if self.pytorch_backend_config.enable_autotuner:
with self.no_cuda_graph(), autotune():
result = get_autotune_warmup_request()
with release_batch(result) as batch:
if batch is None:
# No KV cache space!
pass
else:
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()

if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
for seq_lens in cuda_graph_batch_sizes:
set_enable_piecewise_cuda_graph_capture_flag(True)
logger.info(
f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
)

AutoTuner.get().print_profiling_cache()

if not (self._run_cuda_graphs
or self._torch_compile_piecewise_cuda_graph):
return

logger.info(
f"Creating CUDA graph instances for {len(self._cuda_graph_batch_sizes)} batch sizes."
)
# Reverse the order of the cuda graph batch sizes to make smaller batch size graph could reuse larger batch size graph memory
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = [self.max_draft_len]
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (not self.is_draft_model and self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)

for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
# skip batch size larger than self.batch_size
continue

for draft_len in draft_lengths:
with release_batch(get_cuda_graph_warmup_request(
bs, draft_len)) as batch:
if batch is None:
# No KV cache space!
return
logger.info(
f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()

if self._torch_compile_piecewise_cuda_graph and self._torch_compile_enabled:
piecewise_cuda_graph_num_tokens = sorted(
self._piecewise_cuda_graph_num_tokens, reverse=True)

with capture_piecewise_cuda_graph(True):
for num_tokens in piecewise_cuda_graph_num_tokens:
with self.no_cuda_graph():
with release_batch(
get_torch_compile_warmup_request(
1, seq_lens)) as batch:
1, num_tokens)) as batch:
logger.info(
f"Run piecewise CUDA graph warmup for seq_lens={seq_lens}"
f"Run piecewise CUDA graph warmup for num tokens={num_tokens}"
)
# self.model.mtp_worker.stored_input_ids = []

for _ in range(3):
self.forward(batch,
new_tensors_device=None,
@@ -793,7 +818,6 @@ class PyTorchModelEngine(ModelEngine):
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
set_enable_piecewise_cuda_graph_capture_flag(False)

# Set the value back to the original value
self.enable_spec_decode = self.is_spec_decode

@@ -1541,7 +1565,7 @@ class PyTorchModelEngine(ModelEngine):
# Cache indirection is only used for beam search on generation requests
if self.use_beam_search and num_generation_requests > 0:
# CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph
if cache_indirection_buffer is not None:
#Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
self.cache_indirection_attention[:num_generation_requests].copy_(

@@ -2151,14 +2175,8 @@ class PyTorchModelEngine(ModelEngine):
spec_resource_manager = None
spec_metadata = None

moe_load_balancer = None
if hasattr(self, 'moe_load_balancer'):
moe_load_balancer = getattr(self, 'moe_load_balancer')
if not self.in_warmup:
moe_enable_statistic = True
moe_enable_update = True
moe_load_balancer.set_next_iter_info(moe_enable_statistic,
moe_enable_update)
moe_load_balancer: MoeLoadBalancer = getattr(self, 'moe_load_balancer',
None)

if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
@@ -161,7 +161,6 @@ class PyExecutor:
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
PROFILE_START_STOP_ENV_VAR_NAME)
self.gc_nvtx_watcher_handle = _gc_nvtx_watcher()
self.is_warmup = False # During warmup, we don't enable the profiler

# related modules
self.resource_manager = resource_manager

@@ -220,9 +219,12 @@ class PyExecutor:

self.inflight_req_ids = ReqIdsSet()

# During warmup, we don't enable the profiler
self.is_warmup = True
self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
self.draft_model_engine.warmup(self.resource_manager)
self.is_warmup = False

self.is_shutdown = False
self.max_batch_size = max_batch_size

@@ -280,6 +282,18 @@ class PyExecutor:
finally:
self._executor_loop_cleanup()

@property
def is_warmup(self) -> bool:
return getattr(self, "_is_warmup", False)

@is_warmup.setter
def is_warmup(self, value: bool):
self._is_warmup = value
# Set warmup flag in model engine to trigger torch compile and avoid moe load balancer statistics update
self.model_engine.is_warmup = value
if self.draft_model_engine is not None:
self.draft_model_engine.is_warmup = value

def start_worker(self):
with self.worker_lock:
if self.worker_started == False:
@@ -265,3 +265,13 @@ def set_piecewise_cuda_graph_flag(enable: bool):
def get_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph
return _enable_piecewise_cuda_graph


@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
prev_enable = get_piecewise_cuda_graph_flag()
set_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_piecewise_cuda_graph_flag(prev_enable)
@@ -1990,6 +1990,21 @@ class TorchCompileConfig(StrictBaseModel):
default=False,
description="Enable piecewise CUDA graph in torch.compile.")

capture_num_tokens: Optional[List[int]] = Field(
default=None,
description=
"List of num of tokens to capture the piecewise CUDA graph for. If not provided, the number of tokens will be the same as cuda_graph_config.batch_sizes."
)

@field_validator('capture_num_tokens')
@classmethod
def validate_capture_num_tokens(cls, v):
if v is None:
return v
if any(t <= 0 for t in v):
raise ValueError("capture_num_tokens must contain positive ints.")
return sorted(set(v), reverse=True)
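A usage sketch of the new field (the import path is an assumption; the behavior follows directly from the validator above, which rejects non-positive values and returns a de-duplicated, descending list):

# Assumed path: tensorrt_llm.llmapi.llm_args
from tensorrt_llm.llmapi.llm_args import TorchCompileConfig

cfg = TorchCompileConfig(
    enable_piecewise_cuda_graph=True,
    capture_num_tokens=[1, 64, 64, 8],  # duplicates and arbitrary order are accepted
)
assert cfg.capture_num_tokens == [64, 8, 1]  # de-duplicated and sorted descending
# capture_num_tokens=None (the default) falls back to cuda_graph_config.batch_sizes.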
enable_userbuffers: bool = Field(
default=True,
description=

@@ -2368,6 +2383,10 @@ class TorchLlmArgs(BaseLlmArgs):
enable_piecewise_cuda_graph
if self.torch_compile_config is not None else TorchCompileConfig.
model_fields['enable_piecewise_cuda_graph'].default,
torch_compile_piecewise_cuda_graph_num_tokens=self.
torch_compile_config.capture_num_tokens
if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['capture_num_tokens'].default,
torch_compile_enable_userbuffers=self.torch_compile_config.
enable_userbuffers if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_userbuffers'].default,
@@ -269,7 +269,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with(
10)

balancer.set_next_iter_info(True, True)
balancer.set_iter_info(True, True)

with MoeLoadBalancerIterContext(balancer):
mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(

@@ -308,7 +308,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
balancer.finalize_model()

# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
balancer.set_iter_info(True, False)

# Create sample token data - each token selects 2 experts
# 4 tokens, each selecting 2 experts

@@ -373,7 +373,7 @@ class TestMoeLoadBalancer(unittest.TestCase):
balancer.finalize_model()

# enable statistic, disable weight update
balancer.set_next_iter_info(True, False)
balancer.set_iter_info(True, False)

# Create sample token data - tokens selecting different experts
token_selected_experts = torch.tensor(