fix: Reset planned states to avoid memory leak in TrtllmAttentionWrapper (#4227)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
yuxianq 2025-05-12 23:25:54 +08:00 committed by GitHub
parent 3dbb087292
commit a4c3359513
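
The fix is easiest to see as a pattern. The sketch below is a minimal stand-in (class and attribute names are hypothetical; the real wrapper caches many more tensors in plan() and launches a fused attention op in run()):

import torch


class AttentionWrapperSketch:
    """Caches planned arguments between plan() and run()."""

    def __init__(self):
        self.sequence_length = ...  # Ellipsis marks "not planned"

    def plan(self, sequence_length: torch.Tensor = ...):
        # Storing the tensor keeps its memory alive for as long as the
        # wrapper itself lives; previously the reference was never dropped.
        self.sequence_length = sequence_length

    def run(self, q: torch.Tensor) -> torch.Tensor:
        assert self.sequence_length is not ..., "call plan() first"
        output = q  # stand-in for the real attention computation
        # Re-plan with no arguments: every cached tensor reference is
        # replaced by the Ellipsis sentinel, so the tensor can be freed.
        self.plan()
        return output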


@@ -133,17 +133,17 @@ class TrtllmAttentionWrapper:
         self,
         *,
         tokens_per_block: Optional[int] = None,
-        max_num_requests: int,
-        max_sequence_length: int,
-        max_context_length: int,
+        max_num_requests: int = 0,
+        max_sequence_length: int = 0,
+        max_context_length: int = 0,
         attention_window_size: Optional[int] = None,
         sink_token_length: int = 0,
         beam_width: int = 1,
-        sequence_length: torch.Tensor,
-        host_past_key_value_lengths: torch.Tensor,
-        context_lengths: torch.Tensor,
-        host_context_lengths: torch.Tensor,
-        host_request_types: torch.Tensor,
+        sequence_length: torch.Tensor = ...,
+        host_past_key_value_lengths: torch.Tensor = ...,
+        context_lengths: torch.Tensor = ...,
+        host_context_lengths: torch.Tensor = ...,
+        host_request_types: torch.Tensor = ...,
         kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
@@ -163,6 +163,8 @@ class TrtllmAttentionWrapper:
     ):
         """
         Plan the attention operation.
+        Calling this method without arguments resets the planned states.
+        Required tensor arguments use Ellipsis (...) as the default value to represent an invalid state.
         Args:
             tokens_per_block (int): Token number per KV cache block.
             max_num_requests (int): Max request number per batch.
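
The docstring above leans on Ellipsis rather than None as the "unset" sentinel, because None is already a meaningful value for the optional arguments. A small illustration of the check this enables (the helper name _require_planned is invented, not from this commit):

import torch


def _require_planned(name, value):
    # Ellipsis is a singleton, so an identity check is safe.
    if value is ...:
        raise ValueError(f"{name} was not planned; call plan() with real tensors")
    return value


sequence_length = ...                  # state after a reset
# _require_planned("sequence_length", sequence_length)  # would raise ValueError
sequence_length = torch.tensor([8])    # state after a real plan()
_require_planned("sequence_length", sequence_length)     # passes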
@@ -216,7 +218,6 @@ class TrtllmAttentionWrapper:
             'mrope_rotary_cos_sin') if mrope_config is not None else None
         self.mrope_position_deltas = mrope_config.get(
             'mrope_position_deltas') if mrope_config is not None else None
-        self.kwargs.update(kwargs)
         self.block_ids_per_seq = block_ids_per_seq
         if max_sequence_length > self.rope_params.max_positions:
@@ -224,6 +225,8 @@ class TrtllmAttentionWrapper:
             self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
             )
+        self.kwargs.update(kwargs)
 
     def run(
         self,
         q: torch.Tensor,
@@ -372,6 +375,8 @@ class TrtllmAttentionWrapper:
             self.mrope_rotary_cos_sin,
             self.mrope_position_deltas,
         )
+        # Reset the planned states (especially tensors) to avoid a memory leak.
+        self.plan()
         return output
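
Taken together, run() now ends every call by re-planning back to the invalid state. A usage pass over the hypothetical sketch class above shows the resulting lifecycle:

wrapper = AttentionWrapperSketch()
wrapper.plan(sequence_length=torch.tensor([4]))
out = wrapper.run(torch.randn(4, 8))

# run() finished with a bare plan(), so the wrapper no longer holds the
# batch tensors; once the caller drops its own references, the allocator
# can reuse that memory between forward passes.
assert wrapper.sequence_length is ...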