[https://nvbugs/5550409][fix] Disable torch compile in piecewise attention part to avoid host overhead (#8708)
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
This commit is contained in:
parent d626d13d37
commit a69bd2a6fa
@@ -13,7 +13,8 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
-                     get_piecewise_cuda_graph_flag, make_weak_ref)
+                     get_piecewise_cuda_graph_flag, make_weak_ref,
+                     set_piecewise_running)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function

@@ -27,6 +28,7 @@ class PiecewiseInterpreter(Interpreter):
         compile_time_num_tokens: Union[int | torch.SymInt],
         capture_num_tokens: list[int],
         exclude_modules_id: list[int],
+        piecewise_runner_num: int,
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
         graph=None,
@@ -38,6 +40,8 @@ class PiecewiseInterpreter(Interpreter):

         self.compile_time_num_tokens = compile_time_num_tokens
         self.capture_num_tokens = capture_num_tokens
+        self.piecewise_runner_num = piecewise_runner_num
+        self.piecewise_runner_idx = 0
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -90,8 +94,10 @@ class PiecewiseInterpreter(Interpreter):
                 self.graph_pool_handle,
                 compile_fx(submod, args) if self.enable_inductor else submod,
                 self.enable_inductor,
+                self.piecewise_runner_idx == 0,
+                self.piecewise_runner_idx == self.piecewise_runner_num - 1,
             )

+            self.piecewise_runner_idx += 1
             return output

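The two new arguments mark each runner's position in construction order; a minimal standalone illustration (made-up runner count, not TensorRT-LLM code):

piecewise_runner_num = 4  # example value
for piecewise_runner_idx in range(piecewise_runner_num):
    is_first_runner = piecewise_runner_idx == 0
    is_last_runner = piecewise_runner_idx == piecewise_runner_num - 1
    print(piecewise_runner_idx, is_first_runner, is_last_runner)
# 0 True False
# 1 False False
# 2 False False
# 3 False True
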
@@ -124,6 +130,8 @@ class PiecewiseRunner(object):
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
+        is_first_runner: bool,
+        is_last_runner: bool,
     ):
         if runtime_num_tokens_idx != None:
             assert isinstance(compile_time_num_tokens, torch.SymInt)
@@ -138,6 +146,8 @@ class PiecewiseRunner(object):
         self.enable_inductor = enable_inductor

         self.entries: dict[int, Entry] = {}
+        self.is_first_runner = is_first_runner
+        self.is_last_runner = is_last_runner

         for num_tokens in capture_num_tokens:
             self.entries[num_tokens] = Entry(
@@ -161,6 +171,12 @@ class PiecewiseRunner(object):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)

+        if self.is_first_runner or self.is_last_runner:
+            if self.is_first_runner == self.is_last_runner:
+                set_piecewise_running(False)
+            else:
+                set_piecewise_running(self.is_first_runner)
+
         entry = self.entries[runtime_num_of_token]

         if entry.enable_inductor and not entry.compiled:
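The branch above brackets the attention segments that run between captured pieces: the first runner raises the flag, middle runners leave it untouched, and the last runner lowers it; a runner that is both first and last (a single-runner graph) keeps it lowered, since no excluded segment follows it. A self-contained sketch of that toggling (illustrative only, not the TensorRT-LLM implementation):

_running = False

def set_piecewise_running(enable: bool):
    global _running
    _running = enable

def on_runner_entry(is_first: bool, is_last: bool):
    # Mirrors the condition added in the hunk above.
    if is_first or is_last:
        if is_first == is_last:
            set_piecewise_running(False)
        else:
            set_piecewise_running(is_first)

runner_count = 3
for idx in range(runner_count):
    on_runner_entry(idx == 0, idx == runner_count - 1)
    print(idx, _running)
# 0 True   -- attention ops after the first runner see the flag raised
# 1 True
# 2 False  -- lowered once the last runner begins
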
@@ -267,6 +283,7 @@ def piecewise_optimizer(
        input_num_tokens,
        capture_num_tokens,
        exclude_modules_id,
+       len(set(node_to_graph_id.values())) - len(exclude_modules_id),
        graph_pool_handle,
        max_num_streams=max_num_streams,
    )

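The added argument is the number of piecewise runners: the count of distinct subgraphs minus the excluded ones. A worked example with made-up values:

node_to_graph_id = {"n0": 0, "n1": 0, "n2": 1, "n3": 2, "n4": 3}  # fx node -> subgraph id
exclude_modules_id = [1, 3]  # subgraphs kept out of piecewise capture (e.g. attention)
piecewise_runner_num = len(set(node_to_graph_id.values())) - len(exclude_modules_id)
print(piecewise_runner_num)  # 2
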
@@ -23,7 +23,7 @@ from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+
+    def wrapper(*args, **kwargs):
+        if is_piecewise_running():
+            # When running piecewise, skip torch.compile to avoid host
+            # overhead in the attention op.
+            return func(*args, **kwargs)
+        return torch.compile(func)(*args, **kwargs)
+
+    return wrapper
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)


-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)

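Why this helps: a torch.compile-wrapped helper re-runs Dynamo's guard checks on every call, and on the piecewise path these small helpers execute on the host between captured CUDA graph segments, so that per-call overhead lands directly on the critical path. A self-contained sketch of the maybe_compile pattern (the flag here is local to the example, not the package's global):

import torch

_piecewise_running = False

def is_piecewise_running() -> bool:
    return _piecewise_running

def maybe_compile(func):

    def wrapper(*args, **kwargs):
        if is_piecewise_running():
            # Eager path: no guard evaluation or frame inspection on the host.
            return func(*args, **kwargs)
        return torch.compile(func)(*args, **kwargs)

    return wrapper

@maybe_compile
def cat_last_dim(a, b):
    return torch.cat((a, b), dim=-1)

x, y = torch.randn(2, 4), torch.randn(2, 4)
print(cat_last_dim(x, y).shape)  # compiled path: torch.Size([2, 8])
_piecewise_running = True
print(cat_last_dim(x, y).shape)  # eager path:    torch.Size([2, 8])
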
@@ -1222,8 +1233,9 @@ class MLA(nn.Module):
         )

         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1317,7 +1329,7 @@ class MLA(nn.Module):
             full_k_nope = full_k_nope.view(-1, self.num_heads,
                                            self.qk_nope_head_dim)
             full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            full_k = compiled_cat(
+            full_k = maybe_compiled_cat(
                 (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
             full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)

@@ -1412,7 +1424,7 @@ class MLA(nn.Module):
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = compiled_cat(
+            chunked_k = maybe_compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1482,8 @@ class MLA(nn.Module):

         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

         # copy q_lens to replace kv_lens_runtime

@@ -12,6 +12,7 @@ from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils

 is_torch_compiling_flag = False
+is_piecewise_running_flag = False

 aux_stream_name_list = [
     'Attention',
@@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
     return is_torch_compiling_flag


+def set_piecewise_running(enable: bool):
+    global is_piecewise_running_flag
+    is_piecewise_running_flag = enable
+
+
+def is_piecewise_running() -> bool:
+    global is_piecewise_running_flag
+    return is_piecewise_running_flag
+
+
 _global_attrs = threading.local()

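Unlike _global_attrs just below, which is threading.local(), the new flag is a plain module global and is therefore visible process-wide; a quick illustrative check (assuming the two helpers above are importable):

import threading

set_piecewise_running(True)
seen = []
worker = threading.Thread(target=lambda: seen.append(is_piecewise_running()))
worker.start()
worker.join()
print(seen)  # [True]: the flag is shared across threads, not per-thread
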