[TRTLLM-9198][perf] Add torch.compile + multi-stream support for k-cache scatter and weight scaling (#8988)

Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Authored by Chang Liu on 2025-11-10 20:33:30 -08:00; committed by GitHub
parent c61b44e594
commit 7ceb5e5ab6
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 19 deletions
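
Before the per-file diffs, here is a minimal standalone sketch of the multi-stream pattern this change applies to the k-cache scatter: issue one op on the current CUDA stream and an independent op on an auxiliary stream, fenced by an event, then rejoin. The helper name run_in_parallel and the event handling below are illustrative assumptions; the real code uses TensorRT-LLM's maybe_execute_in_parallel with self.ln_events and self.aux_stream, as shown in the first diff below.

import torch


def run_in_parallel(fn_main, fn_aux, aux_stream: torch.cuda.Stream):
    # Illustrative stand-in for maybe_execute_in_parallel: overlap two
    # independent GPU ops by enqueueing one of them on an auxiliary stream.
    main_stream = torch.cuda.current_stream()
    fork = torch.cuda.Event()
    fork.record(main_stream)             # aux work still waits for earlier main-stream work
    out_main = fn_main()                 # e.g. the weight-scaling op, on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(fork)
        out_aux = fn_aux()               # e.g. the indexer k-cache scatter
    main_stream.wait_stream(aux_stream)  # rejoin before anything consumes out_aux
    return out_main, out_aux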

View File

@@ -17,6 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm._torch.utils import maybe_compile
 from tensorrt_llm._utils import get_size_in_bytes
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig

@@ -572,6 +573,12 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         self.on_update_kv_lens()

+@maybe_compile(dynamic=True)
+def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
+           s: float) -> torch.Tensor:
+    return weights * q_scale.squeeze(-1) * s

 class Indexer(nn.Module):

     def __init__(self,

@@ -964,9 +971,6 @@ class Indexer(nn.Module):
         if not use_custom_topk:
             topk_indices_buffer[:hidden_states.shape[0]] = -1

-        # Store k_fp8 and k_scale into indexer k cache
-        self._update_k_cache(k_fp8, k_scale, metadata)

         if has_prefill:
             # Use chunked prefill to reduce memory footprint
             if metadata.indexer_prefill_chunks is not None:

@@ -1121,9 +1125,7 @@
                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = indexer_weights if indexer_weights is not None else self.weights_proj(
             hidden_states)
-        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
-        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
-        weights = weights.squeeze(-1)
+        weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

     @torch.inference_mode()

@@ -1192,7 +1194,15 @@
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)

-        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
+        weights, _ = maybe_execute_in_parallel(
+            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._update_k_cache(
+                k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )

         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                         k_scale, weights)

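The weight_scale change above replaces an unsqueeze/multiply/squeeze sequence with the compiled _scale helper. A quick equivalence check, using hypothetical shapes for weights and q_scale, confirms the two paths match and that the result is promoted to float32 by the float32 q_scale, which is what the deleted comment relied on:

import torch

weights = torch.randn(4, 8, dtype=torch.bfloat16)   # [num_tokens, n_heads], hypothetical sizes
q_scale = torch.rand(4, 8, 1, dtype=torch.float32)  # [num_tokens, n_heads, 1]
s = 0.5                                              # stand-in for self.weight_scale_factor

old = (weights.unsqueeze(-1) * q_scale * s).squeeze(-1)  # previous code path
new = weights * q_scale.squeeze(-1) * s                  # body of the new _scale helper

assert new.dtype == torch.float32  # promoted by the float32 q_scale
torch.testing.assert_close(old, new)
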
View File

@@ -23,7 +23,7 @@ from ..distributed import AllReduceParams, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_piecewise_running, is_torch_compiling)
+                     is_torch_compiling, maybe_compile)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm

@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer

-def maybe_compile(func):
-
-    def wrapper(*args, **kwargs):
-        if is_piecewise_running():
-            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
-            return func(*args, **kwargs)
-        return torch.compile(func)(*args, **kwargs)
-
-    return wrapper

 @maybe_compile
 def maybe_compiled_copy_(dst, src):
     dst.copy_(src)

View File

@@ -325,3 +325,26 @@ def get_device_uuid(device_idx: int) -> str:
     property = torch.cuda.get_device_properties(device_idx)
     uuid = "GPU-" + str(property.uuid)
     return uuid
+
+
+def maybe_compile(func=None, **compile_kwargs):
+    """
+    Conditionally compile a function with torch.compile.
+    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in the attention op.
+    Args:
+        func: The function to decorate (optional, for direct decoration).
+        **compile_kwargs: Keyword arguments for torch.compile.
+    Returns:
+        The conditionally compiled function.
+    """
+
+    def decorator(f):
+
+        def wrapper(*args, **kwargs):
+            if is_piecewise_running():
+                return f(*args, **kwargs)
+            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
+
+        return wrapper
+
+    return decorator(func) if func else decorator

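Usage sketch for the relocated helper, assuming only what the diffs show (it is importable from tensorrt_llm._torch.utils and supports both bare and parameterized decoration); the decorated functions here are illustrative, not part of the codebase:

import torch

from tensorrt_llm._torch.utils import maybe_compile


@maybe_compile  # bare form: the function is passed directly, torch.compile defaults apply
def compiled_copy_(dst: torch.Tensor, src: torch.Tensor) -> None:
    dst.copy_(src)


@maybe_compile(dynamic=True)  # parameterized form: keyword arguments are forwarded to torch.compile
def compiled_scale(weights: torch.Tensor, q_scale: torch.Tensor,
                   s: float) -> torch.Tensor:
    return weights * q_scale.squeeze(-1) * s

In both forms the wrapper falls back to eager execution when is_piecewise_running() returns True, so attention ops in piecewise-running mode skip the extra host-side dispatch that torch.compile would add.
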
View File

@@ -1175,6 +1175,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
            f" Chunk {i}: Q[{chunk.token_start}:{chunk.token_end}] ({num_q} tokens), "
            f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

+    indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
     topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
                                                        hidden_states, q_fp8,
                                                        k_fp8, k_scale, weights)

@@ -1206,6 +1207,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
            f"✓ Created {num_baseline_chunks} chunk(s) (effectively non-chunked)"
        )

+    indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
                                                         hidden_states, q_fp8,
                                                         k_fp8, k_scale, weights)
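
Because the forward path now folds the k-cache update into maybe_execute_in_parallel, callers that invoke sparse_attn_indexer directly, as this test does, must populate the indexer k cache themselves first. A small helper sketch capturing that call order (run_indexer is a hypothetical name, not part of the API):

def run_indexer(indexer, metadata, hidden_states, q_fp8, k_fp8, k_scale, weights):
    # Fill the indexer k cache before querying it; the production forward path
    # now performs this update concurrently with weight scaling on an aux stream.
    indexer._update_k_cache(k_fp8, k_scale, metadata)
    return indexer.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                       k_scale, weights)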