diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index c30a0dc470..84a9d63b7a 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -17,6 +17,7 @@ from tensorrt_llm._torch.modules.multi_stream_utils import \
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm._torch.utils import maybe_compile
 from tensorrt_llm._utils import get_size_in_bytes
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -572,6 +573,12 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
         self.on_update_kv_lens()
 
 
+@maybe_compile(dynamic=True)
+def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
+           s: float) -> torch.Tensor:
+    return weights * q_scale.squeeze(-1) * s
+
+
 class Indexer(nn.Module):
 
     def __init__(self,
@@ -964,9 +971,6 @@ class Indexer(nn.Module):
         if not use_custom_topk:
             topk_indices_buffer[:hidden_states.shape[0]] = -1
 
-        # Store k_fp8 and k_scale into indexer k cache
-        self._update_k_cache(k_fp8, k_scale, metadata)
-
         if has_prefill:
             # Use chunked prefill to reduce memory footprint
             if metadata.indexer_prefill_chunks is not None:
@@ -1121,9 +1125,7 @@ class Indexer(nn.Module):
                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = indexer_weights if indexer_weights is not None else self.weights_proj(
             hidden_states)
-        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
-        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
-        weights = weights.squeeze(-1)
+        weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights
 
     @torch.inference_mode()
@@ -1192,7 +1194,15 @@ class Indexer(nn.Module):
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)
 
-        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
+        weights, _ = maybe_execute_in_parallel(
+            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._update_k_cache(
+                k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                         k_scale, weights)
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 834eaaee4d..7b5c5e429c 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -23,7 +23,7 @@ from ..distributed import AllReduceParams, alltoall_helix
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_piecewise_running, is_torch_compiling)
+                     is_torch_compiling, maybe_compile)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-def maybe_compile(func):
-
-    def wrapper(*args, **kwargs):
-        if is_piecewise_running():
-            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
-            return func(*args, **kwargs)
-        return torch.compile(func)(*args, **kwargs)
-
-    return wrapper
-
-
 @maybe_compile
 def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 525f9f86f9..4b3349da85 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -325,3 +325,26 @@ def get_device_uuid(device_idx: int) -> str:
     property = torch.cuda.get_device_properties(device_idx)
     uuid = "GPU-" + str(property.uuid)
     return uuid
+
+
+def maybe_compile(func=None, **compile_kwargs):
+    """
+    Conditionally compile a function with torch.compile.
+    If is_piecewise_running() is True, the function is not compiled, to avoid host overhead in the attention op.
+    Args:
+        func: The function to decorate (optional, for direct decoration).
+        **compile_kwargs: Keyword arguments forwarded to torch.compile.
+    Returns:
+        The conditionally compiled function.
+    """
+
+    def decorator(f):
+
+        def wrapper(*args, **kwargs):
+            if is_piecewise_running():
+                return f(*args, **kwargs)
+            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
+
+        return wrapper
+
+    return decorator(func) if func else decorator
diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
index 32756b9773..6cfe276a10 100644
--- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
+++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -1175,6 +1175,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
                 f"  Chunk {i}: Q[{chunk.token_start}:{chunk.token_end}] ({num_q} tokens), "
                 f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")
 
+    indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
     topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
                                                        hidden_states, q_fp8,
                                                        k_fp8, k_scale, weights)
@@ -1206,6 +1207,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
             f"✓ Created {num_baseline_chunks} chunk(s) (effectively non-chunked)"
         )
 
+    indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
                                                         hidden_states, q_fp8,
                                                         k_fp8, k_scale, weights)
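
Usage note (illustrative, not part of the patch): a minimal sketch of the two ways the relocated maybe_compile decorator can now be applied, assuming the import path added in tensorrt_llm/_torch/utils.py above. The toy functions _double and _scale_rows are hypothetical and exist only to show the bare and the parameterized forms.

    import torch

    from tensorrt_llm._torch.utils import maybe_compile


    # Bare form: compiled with torch.compile defaults unless piecewise running.
    @maybe_compile
    def _double(x: torch.Tensor) -> torch.Tensor:
        return x * 2


    # Parameterized form: keyword arguments are forwarded to torch.compile,
    # e.g. dynamic=True to limit recompiles when the token count varies.
    @maybe_compile(dynamic=True)
    def _scale_rows(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return x * scale


    if __name__ == "__main__":
        x = torch.randn(8, 16)
        print(_double(x).shape, _scale_rows(x, torch.rand(8, 1)).shape)

Both forms fall back to eager execution when is_piecewise_running() is true, matching the behavior of the decorator previously defined in modules/attention.py.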
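
Design note: the other functional change overlaps the indexer K-cache update with the weight scaling by dispatching the two closures through maybe_execute_in_parallel on an auxiliary stream. The sketch below is a rough, hypothetical illustration of that two-stream pattern using plain torch.cuda primitives; it is not the implementation of the helper in modules/multi_stream_utils.

    import torch


    def run_two_streams(fn_main, fn_aux, event_start, event_done, aux_stream):
        """Overlap fn_aux on aux_stream with fn_main on the current stream."""
        main_stream = torch.cuda.current_stream()
        event_start.record(main_stream)      # inputs are ready on the main stream
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event_start)
            out_aux = fn_aux()               # e.g. the K-cache update
            event_done.record(aux_stream)
        out_main = fn_main()                 # e.g. the weight scaling, overlapped
        main_stream.wait_event(event_done)   # rejoin before either result is consumed
        return out_main, out_aux

Since the two closures touch disjoint tensors (the indexer K cache versus the projected weights), overlapping them is presumably safe, which is what makes moving the _update_k_cache call out of sparse_attn_indexer worthwhile; the unit tests are updated accordingly to populate the K cache explicitly before calling sparse_attn_indexer.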