From a4c3359513dae5694a2a01955abffb7702b004ab Mon Sep 17 00:00:00 2001
From: yuxianq <142763828+yuxianq@users.noreply.github.com>
Date: Mon, 12 May 2025 23:25:54 +0800
Subject: [PATCH] fix: Reset planned states to avoid memory leak in TrtllmAttentionWrapper (#4227)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 .../_torch/attention_backend/trtllm.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index 6ac5871fc8..0071b92a6a 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -133,17 +133,17 @@ class TrtllmAttentionWrapper:
         self,
         *,
         tokens_per_block: Optional[int] = None,
-        max_num_requests: int,
-        max_sequence_length: int,
-        max_context_length: int,
+        max_num_requests: int = 0,
+        max_sequence_length: int = 0,
+        max_context_length: int = 0,
         attention_window_size: Optional[int] = None,
         sink_token_length: int = 0,
         beam_width: int = 1,
-        sequence_length: torch.Tensor,
-        host_past_key_value_lengths: torch.Tensor,
-        context_lengths: torch.Tensor,
-        host_context_lengths: torch.Tensor,
-        host_request_types: torch.Tensor,
+        sequence_length: torch.Tensor = ...,
+        host_past_key_value_lengths: torch.Tensor = ...,
+        context_lengths: torch.Tensor = ...,
+        host_context_lengths: torch.Tensor = ...,
+        host_request_types: torch.Tensor = ...,
         kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
@@ -163,6 +163,8 @@ class TrtllmAttentionWrapper:
     ):
         """
         Plan the attention operation.
+        Calling this method without arguments resets the planned states.
+        Required arguments use the ellipsis (...) as their default value to represent an invalid state.
         Args:
             tokens_per_block (int): Token number per KV cache block.
             max_num_requests (int): Max request number per batch.
@@ -216,7 +218,6 @@ class TrtllmAttentionWrapper:
             'mrope_rotary_cos_sin') if mrope_config is not None else None
         self.mrope_position_deltas = mrope_config.get(
             'mrope_position_deltas') if mrope_config is not None else None
-        self.kwargs.update(kwargs)
         self.block_ids_per_seq = block_ids_per_seq
 
         if max_sequence_length > self.rope_params.max_positions:
@@ -224,6 +225,8 @@ class TrtllmAttentionWrapper:
             self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
             )
 
+        self.kwargs.update(kwargs)
+
     def run(
         self,
         q: torch.Tensor,
@@ -372,6 +375,8 @@ class TrtllmAttentionWrapper:
             self.mrope_rotary_cos_sin,
             self.mrope_position_deltas,
         )
+        # Reset the planned states (especially tensors) to avoid a memory leak.
+        self.plan()
         return output
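
Note (not part of the patch): the sketch below illustrates the pattern this fix relies on, assuming the wrapper caches per-batch tensors between plan() and run(); calling plan() with its new defaults at the end of run() drops those references so the allocator can free them. PlannerLike and its members are hypothetical stand-ins, not the real TrtllmAttentionWrapper API.

import torch


class PlannerLike:
    """Hypothetical stand-in mimicking the plan()/run() caching pattern."""

    def __init__(self):
        # Start in the "unplanned" state: per-batch fields hold the ellipsis
        # sentinel, so no tensor memory is referenced.
        self.plan()

    def plan(self, *,
             sequence_length: torch.Tensor = ...,
             context_lengths: torch.Tensor = ...):
        # Calling plan() with no arguments stores the ellipsis defaults,
        # dropping references to the previous batch's tensors.
        self.sequence_length = sequence_length
        self.context_lengths = context_lengths

    def run(self, q: torch.Tensor) -> torch.Tensor:
        output = q  # the real wrapper would invoke the attention op here
        # Reset the planned states so per-batch tensors are not kept alive.
        self.plan()
        return output


if __name__ == "__main__":
    wrapper = PlannerLike()
    wrapper.plan(sequence_length=torch.tensor([8]),
                 context_lengths=torch.tensor([8]))
    _ = wrapper.run(torch.zeros(1, 8, 64))
    assert wrapper.sequence_length is ...  # tensor references released after run()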