diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 80e37bba7d..aa32d6317e 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -813,13 +813,14 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
             # Expand schedule metadata buffer (only generation)
             kv_lens_expanded = self.kv_lens_expanded_cuda[:num_tokens]
             scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
-                kv_lens_expanded, tokens_per_block, self.num_sms)
+                kv_lens_expanded, self.kv_cache_manager.tokens_per_block,
+                self.num_sms)
             self.scheduler_metadata_buffer_expanded.copy_(
                 scheduler_metadata_buffer_expanded, non_blocking=True)
         elif self.max_draft_tokens == 3:
             scheduler_metadata_buffer_mtp3 = get_paged_mqa_logits_metadata(
                 self.kv_lens_cuda[self.num_contexts:self.num_seqs],
-                tokens_per_block, self.num_sms // 2)
+                self.kv_cache_manager.tokens_per_block, self.num_sms // 2)
             self.scheduler_metadata_buffer_mtp3.copy_(
                 scheduler_metadata_buffer_mtp3, non_blocking=True)
             self.prepare_dense_topk_indices(self.kv_lens_cuda, device=True)