Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: Reset planned states to avoid memory leak in TrtllmAttentionWrapper (#4227)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Parent: 3dbb087292
Commit: a4c3359513
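Context for the change: plan() stores its tensor arguments as attributes on the wrapper, and before this commit those references outlived run(), so the cached tensors stayed pinned between forward passes. The snippet below is a minimal illustrative sketch (the Holder class is invented for illustration, not the real TrtllmAttentionWrapper) showing how dropping such a cached reference lets the tensor be reclaimed:

import weakref

import torch


class Holder:
    """Illustrative stand-in for a wrapper that caches planned tensors."""

    def plan(self, *, sequence_length: torch.Tensor = ...):
        # Ellipsis marks "not planned"; a bare plan() call resets the state.
        self.sequence_length = sequence_length


holder = Holder()
seq_lens = torch.zeros(1 << 20, dtype=torch.int32)
ref = weakref.ref(seq_lens)
holder.plan(sequence_length=seq_lens)
del seq_lens
assert ref() is not None  # the wrapper attribute still keeps the tensor alive
holder.plan()             # reset, as run() now does after executing the kernel
assert ref() is None      # the tensor is freed and its memory can be reused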
@@ -133,17 +133,17 @@ class TrtllmAttentionWrapper:
         self,
         *,
         tokens_per_block: Optional[int] = None,
-        max_num_requests: int,
-        max_sequence_length: int,
-        max_context_length: int,
+        max_num_requests: int = 0,
+        max_sequence_length: int = 0,
+        max_context_length: int = 0,
         attention_window_size: Optional[int] = None,
         sink_token_length: int = 0,
         beam_width: int = 1,
-        sequence_length: torch.Tensor,
-        host_past_key_value_lengths: torch.Tensor,
-        context_lengths: torch.Tensor,
-        host_context_lengths: torch.Tensor,
-        host_request_types: torch.Tensor,
+        sequence_length: torch.Tensor = ...,
+        host_past_key_value_lengths: torch.Tensor = ...,
+        context_lengths: torch.Tensor = ...,
+        host_context_lengths: torch.Tensor = ...,
+        host_request_types: torch.Tensor = ...,
         kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
         host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
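The hunk above gives formerly required parameters an Ellipsis default, so a bare plan() call is legal and means "reset". A hedged sketch of how such a sentinel can be told apart from real values before use (the helper below is hypothetical, not TensorRT-LLM code):

import torch


def _require_planned(name: str, value):
    # Hypothetical helper: Ellipsis means the argument was reset or never planned.
    if value is ...:
        raise RuntimeError(f"plan() must provide '{name}' before run() is called")
    return value


def kernel_stub(sequence_length: torch.Tensor = ...) -> int:
    seq = _require_planned("sequence_length", sequence_length)
    return int(seq.numel())


assert kernel_stub(torch.arange(4)) == 4
# kernel_stub()  # would raise: the planned state was reset or never set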
@@ -163,6 +163,8 @@ class TrtllmAttentionWrapper:
     ):
         """
         Plan the attention operation.
+        Calling this method without arguments resets the planned states.
+        For required arguments, ellipsis (...) is used as the default to represent an invalid (unplanned) state.
         Args:
             tokens_per_block (int): Token number per KV cache block.
             max_num_requests (int): Max request number per batch.
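To observe the effect the new docstring describes, one can watch the caching allocator's bookkeeping before and after dropping a cached tensor. This is an illustrative check using only public PyTorch APIs, not a test from this repository, and it is skipped when no GPU is available:

import torch

if torch.cuda.is_available():
    baseline = torch.cuda.memory_allocated()
    planned = {"workspace": torch.zeros(1 << 20, device="cuda")}  # cached like an attribute
    assert torch.cuda.memory_allocated() > baseline
    planned["workspace"] = ...  # what a bare plan() call does to each planned tensor
    assert torch.cuda.memory_allocated() == baseline  # memory returned to the allocator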
@@ -216,7 +218,6 @@ class TrtllmAttentionWrapper:
             'mrope_rotary_cos_sin') if mrope_config is not None else None
         self.mrope_position_deltas = mrope_config.get(
             'mrope_position_deltas') if mrope_config is not None else None
-        self.kwargs.update(kwargs)
         self.block_ids_per_seq = block_ids_per_seq

         if max_sequence_length > self.rope_params.max_positions:
@@ -224,6 +225,8 @@ class TrtllmAttentionWrapper:
             self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
             )

+        self.kwargs.update(kwargs)
+
     def run(
         self,
         q: torch.Tensor,
@@ -372,6 +375,8 @@ class TrtllmAttentionWrapper:
             self.mrope_rotary_cos_sin,
             self.mrope_position_deltas,
         )
+        # reset the planned states (especially tensors) to avoid memory leak
+        self.plan()
         return output
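Putting the pieces together, the pattern this hunk introduces is: execute the op with the planned state, then immediately re-plan with no arguments so no tensor references linger on the wrapper. A condensed, illustrative wrapper (not the real implementation; the kernel call is a stand-in):

from typing import Optional

import torch


class SketchWrapper:
    """Condensed illustration of the plan/run/reset cycle."""

    def plan(self, *, sequence_length: torch.Tensor = ...,
             context_lengths: torch.Tensor = ...,
             kv_cache_block_offsets: Optional[torch.Tensor] = None) -> None:
        self.sequence_length = sequence_length
        self.context_lengths = context_lengths
        self.kv_cache_block_offsets = kv_cache_block_offsets

    def run(self, q: torch.Tensor) -> torch.Tensor:
        output = q.clone()  # placeholder for the fused attention kernel
        # reset the planned states (especially tensors) to avoid memory leak
        self.plan()
        return output


wrapper = SketchWrapper()
wrapper.plan(sequence_length=torch.tensor([8]), context_lengths=torch.tensor([8]))
out = wrapper.run(torch.randn(8, 64))
assert wrapper.sequence_length is ...  # no tensor references retained after run()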