Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-9198][perf] Add torch.compile + multi-stream support for k-cache scatter and weight scaling (#8988)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
parent c61b44e594
commit 7ceb5e5ab6
@@ -17,6 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm._torch.utils import maybe_compile
 from tensorrt_llm._utils import get_size_in_bytes
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -572,6 +573,12 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         self.on_update_kv_lens()


+@maybe_compile(dynamic=True)
+def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
+           s: float) -> torch.Tensor:
+    return weights * q_scale.squeeze(-1) * s
+
+
 class Indexer(nn.Module):

     def __init__(self,
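The new _scale helper is a free function so that torch.compile can capture the broadcast-multiply chain as one fused elementwise kernel, and dynamic=True keeps a single compiled graph as the token count varies. Below is a minimal standalone sketch of that idea only; it decorates with torch.compile directly rather than going through the maybe_compile wrapper added later in this commit, and the function name, shapes, and dtypes are illustrative assumptions, not the commit's code path.

import torch

@torch.compile(dynamic=True)
def scale_sketch(weights: torch.Tensor, q_scale: torch.Tensor,
                 s: float) -> torch.Tensor:
    # weights: [num_tokens, n_heads] (bf16), q_scale: [num_tokens, n_heads, 1] (fp32).
    # A single broadcast-multiply chain that the inductor backend can fuse.
    return weights * q_scale.squeeze(-1) * s

device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.randn(8, 64, device=device, dtype=torch.bfloat16)
qs = torch.rand(8, 64, 1, device=device, dtype=torch.float32)
out = scale_sketch(w, qs, 0.5)  # float32 output via type promotion from q_scale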
@@ -964,9 +971,6 @@ class Indexer(nn.Module):
         if not use_custom_topk:
             topk_indices_buffer[:hidden_states.shape[0]] = -1

-        # Store k_fp8 and k_scale into indexer k cache
-        self._update_k_cache(k_fp8, k_scale, metadata)
-
         if has_prefill:
             # Use chunked prefill to reduce memory footprint
             if metadata.indexer_prefill_chunks is not None:
@@ -1121,9 +1125,7 @@ class Indexer(nn.Module):
                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = indexer_weights if indexer_weights is not None else self.weights_proj(
             hidden_states)
-        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
-        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
-        weights = weights.squeeze(-1)
+        weights = _scale(weights, q_scale, self.weight_scale_factor)
        return weights

     @torch.inference_mode()
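The replaced unsqueeze/multiply/squeeze sequence and the new _scale call compute the same values, and the output stays float32 as the removed comment promised, because type promotion from the float32 q_scale upcasts the bf16 weights. A quick self-contained check with illustrative shapes and CPU tensors (the scalar s stands in for self.weight_scale_factor):

import torch

w = torch.randn(4, 8, dtype=torch.bfloat16)    # [num_tokens, n_heads]
qs = torch.rand(4, 8, 1, dtype=torch.float32)  # [num_tokens, n_heads, 1]
s = 0.25                                       # placeholder for weight_scale_factor

old = (w.unsqueeze(-1) * qs * s).squeeze(-1)   # removed path
new = w * qs.squeeze(-1) * s                   # body of _scale
assert old.dtype == new.dtype == torch.float32
assert torch.allclose(old, new)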
@@ -1192,7 +1194,15 @@ class Indexer(nn.Module):
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)

-        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
+        weights, _ = maybe_execute_in_parallel(
+            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._update_k_cache(
+                k_fp8, k_scale, metadata), # store k_fp8 and k_scale in k cache
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                         k_scale, weights)
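maybe_execute_in_parallel overlaps the two lambdas on the main and auxiliary CUDA streams, synchronized with the two events, so the weight scaling no longer serializes behind the indexer k-cache scatter. The sketch below only illustrates that fork/join pattern; it is not TensorRT-LLM's implementation of the helper, and the function name and argument order are invented for the example.

import torch

def overlap_two_ops(fn_main, fn_aux, ev_fork, ev_join, aux_stream):
    # Fork: aux stream waits until the main stream has produced the inputs
    # that fn_aux needs, then fn_aux is enqueued on the aux stream.
    ev_fork.record()                      # records on the current (main) stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(ev_fork)
        out_aux = fn_aux()
        ev_join.record()                  # records on the aux stream
    out_main = fn_main()                  # enqueued on the main stream meanwhile
    # Join: the main stream must not consume out_aux before the aux work finishes.
    torch.cuda.current_stream().wait_event(ev_join)
    return out_main, out_aux

# usage sketch:
# aux = torch.cuda.Stream()
# events = (torch.cuda.Event(), torch.cuda.Event())
# weights, _ = overlap_two_ops(lambda: weight_scale_fn(), lambda: update_k_cache_fn(),
#                              events[0], events[1], aux)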
@@ -23,7 +23,7 @@ from ..distributed import AllReduceParams, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_piecewise_running, is_torch_compiling)
+                     is_torch_compiling, maybe_compile)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-def maybe_compile(func):
-
-    def wrapper(*args, **kwargs):
-        if is_piecewise_running():
-            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
-            return func(*args, **kwargs)
-        return torch.compile(func)(*args, **kwargs)
-
-    return wrapper
-
-
 @maybe_compile
 def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
@@ -325,3 +325,26 @@ def get_device_uuid(device_idx: int) -> str:
     property = torch.cuda.get_device_properties(device_idx)
     uuid = "GPU-" + str(property.uuid)
     return uuid
+
+
+def maybe_compile(func=None, **compile_kwargs):
+    """
+    Conditionally compile a function with torch.compile.
+    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in the attention op.
+    Args:
+        func: The function to decorate (optional, for direct decoration).
+        **compile_kwargs: Keyword arguments for torch.compile.
+    Returns:
+        The conditionally compiled function.
+    """
+
+    def decorator(f):
+
+        def wrapper(*args, **kwargs):
+            if is_piecewise_running():
+                return f(*args, **kwargs)
+            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
+
+        return wrapper
+
+    return decorator(func) if func else decorator
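With the decorator now accepting keyword arguments, both call forms used in this commit work: bare decoration (as in maybe_compiled_copy_) and parameterized decoration (as in the indexer's _scale). A small usage sketch, assuming the import added at the top of the first file; the function names here are hypothetical stand-ins:

import torch
from tensorrt_llm._torch.utils import maybe_compile

@maybe_compile                 # bare form, as in maybe_compiled_copy_
def copy_sketch(dst, src):
    dst.copy_(src)

@maybe_compile(dynamic=True)   # parameterized form, as in the indexer's _scale
def scale_sketch(weights, q_scale, s):
    return weights * q_scale.squeeze(-1) * s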
@@ -1175,6 +1175,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
             f" Chunk {i}: Q[{chunk.token_start}:{chunk.token_end}] ({num_q} tokens), "
             f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

+    indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
     topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
                                                        hidden_states, q_fp8,
                                                        k_fp8, k_scale, weights)
@@ -1206,6 +1207,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
                 f"✓ Created {num_baseline_chunks} chunk(s) (effectively non-chunked)"
             )

+    indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
                                                         hidden_states, q_fp8,
                                                         k_fp8, k_scale, weights)