[https://nvbugs/5550409][fix] Disable torch compile in piecewise attention part to Avoid host overhead (#8708)

Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
This commit is contained in:
Yi Zhang 2025-10-29 18:12:58 +08:00 committed by GitHub
parent d626d13d37
commit a69bd2a6fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 12 deletions

View File

@ -13,7 +13,8 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..utils import (get_model_extra_attrs,
get_per_request_piecewise_cuda_graph_flag,
get_piecewise_cuda_graph_flag, make_weak_ref)
get_piecewise_cuda_graph_flag, make_weak_ref,
set_piecewise_running)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
@ -27,6 +28,7 @@ class PiecewiseInterpreter(Interpreter):
compile_time_num_tokens: Union[int | torch.SymInt],
capture_num_tokens: list[int],
exclude_modules_id: list[int],
piecewise_runner_num: int,
graph_pool_handle: tuple[int, int],
garbage_collect_values: bool = True,
graph=None,
@ -38,6 +40,8 @@ class PiecewiseInterpreter(Interpreter):
self.compile_time_num_tokens = compile_time_num_tokens
self.capture_num_tokens = capture_num_tokens
self.piecewise_runner_num = piecewise_runner_num
self.piecewise_runner_idx = 0
self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
self.graph_pool_handle = graph_pool_handle
self.enable_inductor = enable_inductor
@ -90,8 +94,10 @@ class PiecewiseInterpreter(Interpreter):
self.graph_pool_handle,
compile_fx(submod, args) if self.enable_inductor else submod,
self.enable_inductor,
self.piecewise_runner_idx == 0,
self.piecewise_runner_idx == self.piecewise_runner_num - 1,
)
self.piecewise_runner_idx += 1
return output
@ -124,6 +130,8 @@ class PiecewiseRunner(object):
graph_pool_handle,
default_callable: Callable,
enable_inductor: bool,
is_first_runner: bool,
is_last_runner: bool,
):
if runtime_num_tokens_idx != None:
assert isinstance(compile_time_num_tokens, torch.SymInt)
@ -138,6 +146,8 @@ class PiecewiseRunner(object):
self.enable_inductor = enable_inductor
self.entries: dict[int, Entry] = {}
self.is_first_runner = is_first_runner
self.is_last_runner = is_last_runner
for num_tokens in capture_num_tokens:
self.entries[num_tokens] = Entry(
@ -161,6 +171,12 @@ class PiecewiseRunner(object):
or not get_per_request_piecewise_cuda_graph_flag()):
return self.default_callable(*args)
if self.is_first_runner or self.is_last_runner:
if self.is_first_runner == self.is_last_runner:
set_piecewise_running(False)
else:
set_piecewise_running(self.is_first_runner)
entry = self.entries[runtime_num_of_token]
if entry.enable_inductor and not entry.compiled:
@ -267,6 +283,7 @@ def piecewise_optimizer(
input_num_tokens,
capture_num_tokens,
exclude_modules_id,
len(set(node_to_graph_id.values())) - len(exclude_modules_id),
graph_pool_handle,
max_num_streams=max_num_streams,
)

View File

@ -23,7 +23,7 @@ from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
is_torch_compiling)
is_piecewise_running, is_torch_compiling)
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
return metadata, attn_layer
@torch.compile
def compiled_copy_(dst, src):
def maybe_compile(func):
def wrapper(*args, **kwargs):
if is_piecewise_running():
# When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
return func(*args, **kwargs)
return torch.compile(func)(*args, **kwargs)
return wrapper
@maybe_compile
def maybe_compiled_copy_(dst, src):
dst.copy_(src)
@torch.compile
def compiled_cat(tensors, dim):
@maybe_compile
def maybe_compiled_cat(tensors, dim):
return torch.cat(tensors, dim)
@ -1222,8 +1233,9 @@ class MLA(nn.Module):
)
k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
compiled_copy_(k[..., :self.qk_nope_head_dim],
k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
maybe_compiled_copy_(
k[..., :self.qk_nope_head_dim],
k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
if self.apply_rotary_emb:
k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
self.qk_rope_head_dim)
@ -1317,7 +1329,7 @@ class MLA(nn.Module):
full_k_nope = full_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
full_k = compiled_cat(
full_k = maybe_compiled_cat(
(full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
@ -1412,7 +1424,7 @@ class MLA(nn.Module):
chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
chunked_k = compiled_cat(
chunked_k = maybe_compiled_cat(
(chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
dim=-1)
chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@ -1470,7 +1482,8 @@ class MLA(nn.Module):
k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
dim=-1)
k = k.view(-1, self.num_heads * self.qk_head_dim)
# copy q_lens to replace kv_lens_runtime

View File

@ -12,6 +12,7 @@ from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
is_torch_compiling_flag = False
is_piecewise_running_flag = False
aux_stream_name_list = [
'Attention',
@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
return is_torch_compiling_flag
def set_piecewise_running(enable: bool):
global is_piecewise_running_flag
is_piecewise_running_flag = enable
def is_piecewise_running() -> bool:
global is_piecewise_running_flag
return is_piecewise_running_flag
_global_attrs = threading.local()