/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/mlaChunkedPrefill.cuh"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
#include <cuda_fp8.h>              // __nv_fp8_e4m3

namespace tk = tensorrt_llm::kernels;
namespace tc = tensorrt_llm::common;

using tk::KVBlockArray;

namespace torch_ext
{

namespace
{

// Gathers cached compressed KV and k_pe values out of the paged KV cache into contiguous
// output tensors, dequantizing when the cache element type TCache differs from T.
template <typename T, typename TCache>
void loadPagedKVCacheForMLAHelper(torch::Tensor& compressed_kv, torch::Tensor& k_pe, KVBlockArray& kv_cache,
    int const num_contexts, torch::Tensor const& cu_ctx_cached_kv_lens, int const max_input_seq_len,
    int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr)
{
    auto stream = at::cuda::getCurrentCUDAStream(compressed_kv.get_device());
    auto* compressed_kv_ptr = static_cast<T*>(compressed_kv.data_ptr());
    auto* k_pe_ptr = static_cast<T*>(k_pe.data_ptr());
    auto const* cu_ctx_cached_kv_lens_ptr = cu_ctx_cached_kv_lens.data_ptr<int64_t>();
    tensorrt_llm::kernels::invokeMLALoadPagedKV<T, TCache>(compressed_kv_ptr, k_pe_ptr, kv_cache, num_contexts,
        cu_ctx_cached_kv_lens_ptr, max_input_seq_len, lora_size, rope_size, kv_scale_quant_orig_ptr, stream);
}

// Gathers one KV chunk (index chunked_idx, at most chunked_size tokens per request) out of
// the paged KV cache for chunked prefill.
template <typename T, typename TCache>
void loadChunkedKVCacheForMLAHelper(torch::Tensor& output_kv, torch::Tensor& output_k_pe, KVBlockArray& kv_cache,
    int const num_contexts, torch::Tensor const& cu_ctx_chunked_len, int lora_size, int rope_size,
    int const chunked_size, int const chunked_idx, float const* kv_scale_quant_orig_ptr)
{
    auto stream = at::cuda::getCurrentCUDAStream(output_kv.get_device());
    T* output_kv_ptr = static_cast<T*>(output_kv.data_ptr());
    T* output_k_pe_ptr = static_cast<T*>(output_k_pe.data_ptr());
    tensorrt_llm::kernels::invokeMLALoadChunkedKV<T, TCache>(output_kv_ptr, output_k_pe_ptr, kv_cache, num_contexts,
        cu_ctx_chunked_len.data_ptr<int64_t>(), lora_size, rope_size, chunked_size, chunked_idx,
        kv_scale_quant_orig_ptr, stream);
}

// Packs separate k/v/k_pe tensors into a block-paged output buffer.
template <typename T>
void setPagedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
    torch::Tensor const& k_pe, int const num_requests, torch::Tensor const& cu_seq_lens, int const max_input_seq_len,
    int num_heads, int kv_dim, int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
    auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
    auto* output_ptr = static_cast<T*>(output.data_ptr());
    auto const* k_ptr = static_cast<T const*>(k.data_ptr());
    auto const* v_ptr = static_cast<T const*>(v.data_ptr());
    auto const* k_pe_ptr = static_cast<T const*>(k_pe.data_ptr());
    auto const* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
    // cudaMemset is faster than torch::zeros
    TLLM_CUDA_CHECK(cudaMemsetAsync(output_ptr, 0, output.numel() * torch::elementSize(output.scalar_type()), stream));
    tensorrt_llm::kernels::invokeMLASetPagedKV<T>(output_ptr, k_ptr, v_ptr, k_pe_ptr, num_requests, cu_seq_lens_ptr,
        max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride, stream);
}

// Packs an already-concatenated kv tensor plus k_pe into a block-paged output buffer.
template <typename T>
void setChunkedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
    int const num_requests, torch::Tensor const& cu_seq_lens, int num_heads, int kv_dim, int rope_dim,
    int kv_cache_tokens_per_block, int max_seq_len)
{
    auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
    T* output_ptr = static_cast<T*>(output.data_ptr());
    T* kv_ptr = static_cast<T*>(kv.data_ptr());
    T* k_pe_ptr = static_cast<T*>(k_pe.data_ptr());
    auto* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
    tensorrt_llm::kernels::invokeMLASetChunkedKV<T>(output_ptr, kv_ptr, k_pe_ptr, num_requests, max_seq_len, num_heads,
        kv_dim, rope_dim, cu_seq_lens_ptr, kv_cache_tokens_per_block, stream);
}

// Applies RoPE, appends the uncached tokens' latent (compressed KV plus k_pe) to the paged
// KV cache (quantizing when TCache is FP8), and updates the rope portion of q in place.
template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tensor& q, torch::Tensor& latent_cache,
    int const num_requests, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
    int const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int const head_num, int const nope_size,
    int const rope_size, int const lora_size, float const* kv_scale_orig_quant_ptr)
{
    auto stream = at::cuda::getCurrentCUDAStream(q.get_device());
    auto* q_ptr = static_cast<T*>(q.data_ptr());
    auto* latent_cache_ptr = static_cast<T*>(latent_cache.data_ptr());
    auto const* cu_ctx_cached_kv_lens_ptr = cu_ctx_cached_kv_lens.data_ptr<int64_t>();
    auto const* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
    auto const* cos_sin_cache_ptr = static_cast<float2 const*>(cos_sin_cache.data_ptr());
    tensorrt_llm::kernels::invokeMLARopeAppendPagedKVAssignQ<T, TCache>(kv_cache, q_ptr, latent_cache_ptr,
        num_requests, cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, max_input_uncached_seq_len, cos_sin_cache_ptr,
        head_num, nope_size, rope_size, lora_size, kv_scale_orig_quant_ptr, stream);
}

// Merges the current chunk's attention output and softmax statistics into the running
// merged result according to merge_op.
template <typename T>
void mergeChunkedAttentionForMLAHelper(torch::Tensor& merged_attn, torch::Tensor const& temp_attn,
    torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, int64_t const num_requests,
    torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len, torch::Tensor const& merge_op,
    int64_t const num_heads, int64_t const head_size)
{
    auto stream = at::cuda::getCurrentCUDAStream(merged_attn.get_device());
    T* merged_attn_ptr = static_cast<T*>(merged_attn.data_ptr());
    T* temp_attn_ptr = static_cast<T*>(temp_attn.data_ptr());
    float* merged_softmax_stats_ptr = static_cast<float*>(merged_softmax_stats.data_ptr());
    float* temp_softmax_stats_ptr = static_cast<float*>(temp_softmax_stats.data_ptr());
    int64_t* const cu_q_seq_lens_ptr = cu_q_seq_lens.data_ptr<int64_t>();
    int64_t* const merge_op_ptr = merge_op.data_ptr<int64_t>();
    tensorrt_llm::kernels::invokeMergeAttnWithSoftmax<T>(merged_attn_ptr, merged_softmax_stats_ptr, merged_attn_ptr,
        merged_softmax_stats_ptr, temp_attn_ptr, temp_softmax_stats_ptr, num_requests, cu_q_seq_lens_ptr,
        max_q_seq_len, merge_op_ptr, num_heads, head_size, stream);
}
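// The helpers above share a dispatch convention: T is the torch tensor dtype (half, float,
// or __nv_bfloat16) and, where present, TCache is the element type stored in the paged KV
// cache (equal to T, or __nv_fp8_e4m3 when FP8 KV-cache quantization is enabled). The
// public ops below pick the instantiation from the tensor dtype and the KV cache quant mode.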
/**
 * Creates a KVBlockArray object for managing the paged KV cache.
 *
 * @param num_contexts Number of contexts
 * @param max_blocks_per_sequence Maximum blocks per sequence
 * @param tokens_per_block Number of tokens per block
 * @param head_size Size of each head
 * @param num_kv_heads Number of KV heads (1 for MLA)
 * @param attention_window_size Attention window size
 * @param sink_token_length Sink token length
 * @param beam_width Beam width
 * @param kv_cache_quant_mode KV cache quantization mode
 * @param orig_dtype Original data type
 * @param host_kv_cache_pool_pointers Host KV cache pool pointers
 * @param host_kv_cache_pool_mapping Host KV cache pool mapping
 * @param kv_cache_block_offsets KV cache block offsets
 * @param layer_idx Layer index
 * @return Constructed KVBlockArray object
 */
KVBlockArray createKVBlockArray(int num_contexts, int max_blocks_per_sequence, int tokens_per_block, int head_size,
    int num_kv_heads, int attention_window_size, int sink_token_length, int beam_width,
    tc::QuantMode kv_cache_quant_mode, torch::Dtype orig_dtype, torch::Tensor const& host_kv_cache_pool_pointers,
    torch::Tensor const& host_kv_cache_pool_mapping, torch::Tensor const& kv_cache_block_offsets, int layer_idx)
{
    auto const orig_elem_size = torch::elementSize(orig_dtype);
    auto const cache_elem_size = kv_cache_quant_mode.hasKvCacheQuant() ? sizeof(int8_t) : orig_elem_size;
    auto const size_per_token = num_kv_heads * head_size * cache_elem_size;
    int const cyclic_attention_window_size = attention_window_size;
    int const max_cyclic_attention_window_size = attention_window_size;
    bool const can_use_one_more_block = beam_width > 1;
    auto const pool_index = host_kv_cache_pool_mapping.index({layer_idx, 0}).item<int32_t>();
    auto const layer_idx_in_cache_pool = host_kv_cache_pool_mapping.index({layer_idx, 1}).item<int32_t>();
    int32_t const seq_offset = 0;
    KVBlockArray::DataType* block_offsets
        = static_cast<KVBlockArray::DataType*>(kv_cache_block_offsets.index({pool_index, seq_offset}).data_ptr());
    auto const block_size = tokens_per_block * num_kv_heads * head_size;
    auto const bytes_per_block = block_size * cache_elem_size;
    int32_t const kv_factor = 1; // always 1 for MLA
    auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block;
    void* host_primary_pool_pointer = reinterpret_cast<void*>(
        reinterpret_cast<int8_t*>(host_kv_cache_pool_pointers.index({pool_index, 0}).item<int64_t>())
        + intra_pool_offset);
    void* host_secondary_pool_pointer = reinterpret_cast<void*>(
        reinterpret_cast<int8_t*>(host_kv_cache_pool_pointers.index({pool_index, 1}).item<int64_t>())
        + intra_pool_offset);
    return KVBlockArray(num_contexts, max_blocks_per_sequence, tokens_per_block, size_per_token,
        cyclic_attention_window_size, max_cyclic_attention_window_size, sink_token_length, can_use_one_more_block,
        host_primary_pool_pointer, host_secondary_pool_pointer, block_offsets);
}

} // namespace
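// Layout assumed by createKVBlockArray (a sketch inferred from the indexing above, not an
// authoritative description of the cache manager):
//   host_kv_cache_pool_mapping[layer_idx]   -> {pool_index, layer_idx_in_cache_pool}
//   host_kv_cache_pool_pointers[pool_index] -> {primary_pool_ptr, secondary_pool_ptr}
//   kv_cache_block_offsets[pool_index][seq] -> per-sequence block offsets into that pool
// With kv_factor == 1 (MLA keeps a single compressed KV head), a layer's byte offset inside
// its pool is layer_idx_in_cache_pool * tokens_per_block * head_size * cache_elem_size,
// which is the intra_pool_offset computed above.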
std::vector<torch::Tensor> loadPagedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts,
    int64_t const num_ctx_cached_tokens, int64_t const max_ctx_cached_kv_len, torch::Tensor& cu_ctx_cached_kv_lens,
    torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
    torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
    torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
    int64_t const layer_idx, int64_t const lora_size, int64_t const rope_size, int64_t const tokens_per_block,
    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
    int64_t const quant_mode)
{
    TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16,
        "out_dtype only supports float16, float32 and bfloat16");
    TLLM_CHECK(num_contexts > 0);
    TORCH_CHECK(num_ctx_cached_tokens > 0);
    TLLM_CHECK(max_ctx_cached_kv_len > 0);
    CHECK_INPUT(cu_ctx_cached_kv_lens, torch::kInt64);
    TORCH_CHECK(cu_ctx_cached_kv_lens.dim() == 1);
    TORCH_CHECK(cu_ctx_cached_kv_lens.size(0) >= num_contexts + 1);

    auto kv_cache_quant_mode = tc::QuantMode(static_cast<uint32_t>(quant_mode));
    int max_blocks_per_sequence = kv_cache_block_offsets.size(-1);
    int head_size = lora_size + rope_size;
    KVBlockArray kv_cache_buffer = createKVBlockArray(num_contexts, max_blocks_per_sequence, tokens_per_block,
        head_size, 1, // num_kv_heads is always 1 for MLA
        attention_window_size, sink_token_length, beam_width, kv_cache_quant_mode, out_dtype,
        host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, kv_cache_block_offsets, layer_idx);

    float const* kv_scale_orig_quant_ptr = nullptr;
    float const* kv_scale_quant_orig_ptr = nullptr;
    if (kv_cache_quant_mode.hasKvCacheQuant())
    {
        TLLM_CHECK_WITH_INFO(kv_cache_quant_mode.hasFp8KvCache(), "Only FP8 KV cache is supported for now");
        TORCH_CHECK(kv_scale_orig_quant.has_value());
        TORCH_CHECK(kv_scale_quant_orig.has_value());
        kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr<float>();
        kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr<float>();
        TLLM_CHECK(kv_scale_orig_quant_ptr != nullptr);
        TLLM_CHECK(kv_scale_quant_orig_ptr != nullptr);
    }

    std::vector<torch::Tensor> outputs;
    // compressed_kv {num_ctx_cached_tokens, lora_size}
    outputs.push_back(torch::empty(
        {num_ctx_cached_tokens, lora_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false)));
    // k_pe {num_ctx_cached_tokens, rope_size}
    outputs.push_back(torch::empty(
        {num_ctx_cached_tokens, rope_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false)));

    if (out_dtype == torch::kFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadPagedKVCacheForMLAHelper<half, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size, kv_scale_quant_orig_ptr);
        }
        else
        {
            loadPagedKVCacheForMLAHelper<half, half>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size, kv_scale_quant_orig_ptr);
        }
    }
    else if (out_dtype == torch::kFloat32)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadPagedKVCacheForMLAHelper<float, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size, kv_scale_quant_orig_ptr);
        }
        else
        {
            loadPagedKVCacheForMLAHelper<float, float>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size, kv_scale_quant_orig_ptr);
        }
    }
    else if (out_dtype == torch::kBFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadPagedKVCacheForMLAHelper<__nv_bfloat16, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer,
                num_contexts, cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size,
                kv_scale_quant_orig_ptr);
        }
        else
        {
            loadPagedKVCacheForMLAHelper<__nv_bfloat16, __nv_bfloat16>(outputs[0], outputs[1], kv_cache_buffer,
                num_contexts, cu_ctx_cached_kv_lens, max_ctx_cached_kv_len, lora_size, rope_size,
                kv_scale_quant_orig_ptr);
        }
    }
    return outputs;
}
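// Chunked-prefill usage of the ops below (a sketch of the intended flow, pieced together
// from the kernels in mlaChunkedPrefill.cuh rather than a normative description): for each
// chunk of cached KV, load_chunked_kv_cache_for_mla gathers chunk `chunked_index` (at most
// `chunked_size` tokens per request) from the paged cache, the caller runs attention of the
// uncached queries against that chunk, and merge_chunked_attention_for_mla folds the
// chunk's output and softmax statistics into the running result according to merge_op.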
std::vector<torch::Tensor> loadChunkedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts,
    int64_t const num_ctx_cached_tokens, torch::Tensor& cu_ctx_chunked_kv_lens,
    torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers,
    torch::Tensor const& host_kv_cache_pool_mapping, torch::optional<torch::Tensor> kv_scale_orig_quant,
    torch::optional<torch::Tensor> kv_scale_quant_orig, int64_t const layer_idx, int64_t const lora_size,
    int64_t const rope_size, int64_t const tokens_per_block, int64_t const chunked_size, int64_t const chunked_index,
    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
    int64_t const quant_mode)
{
    TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16,
        "out_dtype only supports float16, float32 and bfloat16");
    TLLM_CHECK(num_contexts > 0);
    CHECK_INPUT(cu_ctx_chunked_kv_lens, torch::kInt64);
    TORCH_CHECK(cu_ctx_chunked_kv_lens.dim() == 1);
    TORCH_CHECK(cu_ctx_chunked_kv_lens.size(0) >= num_contexts + 1);

    int head_size = lora_size + rope_size;
    auto kv_cache_quant_mode = tc::QuantMode(static_cast<uint32_t>(quant_mode));
    int max_blocks_per_sequence = kv_cache_block_offsets.size(-1);
    KVBlockArray kv_cache_buffer = createKVBlockArray(num_contexts, max_blocks_per_sequence, tokens_per_block,
        head_size, 1, // num_kv_heads is always 1 for MLA
        attention_window_size, sink_token_length, beam_width, kv_cache_quant_mode, out_dtype,
        host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, kv_cache_block_offsets, layer_idx);

    float const* kv_scale_orig_quant_ptr = nullptr;
    float const* kv_scale_quant_orig_ptr = nullptr;
    if (kv_cache_quant_mode.hasKvCacheQuant())
    {
        TORCH_CHECK(kv_scale_orig_quant.has_value());
        TORCH_CHECK(kv_scale_quant_orig.has_value());
        kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr<float>();
        kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr<float>();
        TLLM_CHECK(kv_scale_orig_quant_ptr != nullptr);
        TLLM_CHECK(kv_scale_quant_orig_ptr != nullptr);
    }

    std::vector<torch::Tensor> outputs;
    // compressed_kv {num_ctx_cached_tokens, lora_size}
    outputs.push_back(torch::empty(
        {num_ctx_cached_tokens, lora_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false)));
    // k_pe {num_ctx_cached_tokens, rope_size}
    outputs.push_back(torch::empty(
        {num_ctx_cached_tokens, rope_size}, torch::dtype(out_dtype).device(torch::kCUDA).requires_grad(false)));

    if (out_dtype == torch::kFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadChunkedKVCacheForMLAHelper<half, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
        }
        else
        {
            loadChunkedKVCacheForMLAHelper<half, half>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
        }
    }
    else if (out_dtype == torch::kFloat32)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadChunkedKVCacheForMLAHelper<float, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
        }
        else
        {
            loadChunkedKVCacheForMLAHelper<float, float>(outputs[0], outputs[1], kv_cache_buffer, num_contexts,
                cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index, kv_scale_quant_orig_ptr);
        }
    }
    else if (out_dtype == torch::kBFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            loadChunkedKVCacheForMLAHelper<__nv_bfloat16, __nv_fp8_e4m3>(outputs[0], outputs[1], kv_cache_buffer,
                num_contexts, cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index,
                kv_scale_quant_orig_ptr);
        }
        else
        {
            loadChunkedKVCacheForMLAHelper<__nv_bfloat16, __nv_bfloat16>(outputs[0], outputs[1], kv_cache_buffer,
                num_contexts, cu_ctx_chunked_kv_lens, lora_size, rope_size, chunked_size, chunked_index,
                kv_scale_quant_orig_ptr);
        }
    }
    return outputs;
}

torch::Tensor setPagedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
    torch::Tensor const& k_pe, int64_t const num_requests, torch::Tensor const& cu_seq_lens,
    int64_t const max_input_seq_len, int64_t const num_heads, int64_t const kv_dim, int64_t const rope_dim,
    int64_t const kv_cache_tokens_per_block)
{
    TORCH_CHECK(output.numel() > 0);
    auto output_dtype = output.scalar_type();
    TORCH_CHECK(
        output_dtype == torch::kFloat16 || output_dtype == torch::kFloat32 || output_dtype == torch::kBFloat16);
    CHECK_TH_CUDA(output);
    CHECK_CONTIGUOUS(output);
    // k and v can be non-contiguous
    CHECK_TH_CUDA(k);
    CHECK_TYPE(k, output_dtype);
    CHECK_TH_CUDA(v);
    CHECK_TYPE(v, output_dtype);
    TORCH_CHECK(k.dim() == 3);
    TORCH_CHECK(v.dim() == 3);
    TORCH_CHECK(k.size(0) == v.size(0));
    TORCH_CHECK(k.size(1) == v.size(1));
    TORCH_CHECK(k.size(2) == v.size(2));
    TORCH_CHECK(k.stride(1) == k.size(2));
    TORCH_CHECK(v.stride(1) == v.size(2));
    TORCH_CHECK(k.stride(2) == 1);
    TORCH_CHECK(v.stride(2) == 1);
    // k and v should have the same token stride
    int64_t k_token_stride = k.stride(0);
    int64_t v_token_stride = v.stride(0);
    TORCH_CHECK(k_token_stride == v_token_stride);
    // k_pe should be contiguous
    CHECK_INPUT(k_pe, output_dtype);
    CHECK_INPUT(cu_seq_lens, torch::kInt64);
    TORCH_CHECK(cu_seq_lens.dim() == 1);
    TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);

    if (output_dtype == torch::kFloat16)
    {
        setPagedKVCacheForMLAHelper<half>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
            kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
    }
    else if (output_dtype == torch::kFloat32)
    {
        setPagedKVCacheForMLAHelper<float>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
            kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
    }
    else if (output_dtype == torch::kBFloat16)
    {
        setPagedKVCacheForMLAHelper<__nv_bfloat16>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len,
            num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
    }

    int64_t max_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    torch::Tensor faked_kv_cache_block_offsets = torch::arange(
        0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
    faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
    return faked_kv_cache_block_offsets;
}

torch::Tensor setChunkedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
    int64_t const num_requests, torch::Tensor const& cu_seq_lens, int64_t const num_heads, int64_t const kv_dim,
    int64_t const rope_dim, int64_t const kv_cache_tokens_per_block, int64_t const max_seq_len)
{
    TORCH_CHECK(output.numel() > 0);
    TORCH_CHECK(output.scalar_type() == torch::kFloat16 || output.scalar_type() == torch::kFloat32
        || output.scalar_type() == torch::kBFloat16);
    CHECK_TH_CUDA(output);
    CHECK_CONTIGUOUS(output);
    CHECK_INPUT(kv, output.scalar_type());
    CHECK_INPUT(k_pe, output.scalar_type());
    CHECK_INPUT(cu_seq_lens, torch::kInt64);
    TORCH_CHECK(cu_seq_lens.dim() == 1);
    TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);

    if (output.scalar_type() == torch::kFloat16)
    {
        setChunkedKVCacheForMLAHelper<half>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
            kv_cache_tokens_per_block, max_seq_len);
    }
    else if (output.scalar_type() == torch::kFloat32)
    {
        setChunkedKVCacheForMLAHelper<float>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
            kv_cache_tokens_per_block, max_seq_len);
    }
    else if (output.scalar_type() == torch::kBFloat16)
    {
        setChunkedKVCacheForMLAHelper<__nv_bfloat16>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim,
            rope_dim, kv_cache_tokens_per_block, max_seq_len);
    }

    int64_t max_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    // TODO: actually this offset is always the same for all requests and all layers.
    torch::Tensor faked_kv_cache_block_offsets = torch::arange(
        0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
    faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
    return faked_kv_cache_block_offsets;
}
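// Note on the block-offset tables returned by the two set*KVCacheForMLA ops above: they are
// synthetic identity mappings, arange(num_requests * 2 * max_block_num) reshaped to
// [num_requests, 2, max_block_num], presumably so that the contiguous scratch buffer written
// by the kernels can be consumed by attention code that expects a paged layout. Because the
// scratch buffer is laid out the same way for every request and layer, the identity table
// (flagged by the TODO above) could in principle be built once and reused.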
void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache, int64_t const num_contexts,
    torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
    int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num,
    int64_t const nope_size, int64_t const rope_size, int64_t const lora_size,
    torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
    torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
    torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
    int64_t const layer_idx, int64_t const tokens_per_block, int64_t const attention_window_size,
    int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode)
{
    auto input_dtype = q.scalar_type();
    TORCH_CHECK(input_dtype == torch::kFloat16 || input_dtype == torch::kFloat32 || input_dtype == torch::kBFloat16);
    TORCH_CHECK(q.numel() > 0);
    TORCH_CHECK(q.dim() == 2);
    CHECK_TH_CUDA(q);
    CHECK_CONTIGUOUS(q);
    CHECK_INPUT(latent_cache, input_dtype);
    TORCH_CHECK(latent_cache.dim() == 2);
    CHECK_INPUT(cu_seq_lens, torch::kInt64);
    TORCH_CHECK(cu_seq_lens.dim() == 1);
    TORCH_CHECK(cu_seq_lens.size(0) >= num_contexts + 1);
    CHECK_INPUT(cu_ctx_cached_kv_lens, torch::kInt64);
    TORCH_CHECK(cu_ctx_cached_kv_lens.dim() == 1);
    TORCH_CHECK(cu_ctx_cached_kv_lens.size(0) >= num_contexts + 1);
    TORCH_CHECK(max_input_uncached_seq_len > 0);

    auto kv_cache_quant_mode = tc::QuantMode(static_cast<uint32_t>(quant_mode));
    int max_blocks_per_sequence = kv_cache_block_offsets.size(-1);
    int head_size = lora_size + rope_size;
    KVBlockArray kv_cache_buffer = createKVBlockArray(num_contexts, max_blocks_per_sequence, tokens_per_block,
        head_size, 1, // num_kv_heads is always 1 for MLA
        attention_window_size, sink_token_length, beam_width, kv_cache_quant_mode, input_dtype,
        host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, kv_cache_block_offsets, layer_idx);

    float const* kv_scale_orig_quant_ptr = nullptr;
    float const* kv_scale_quant_orig_ptr = nullptr;
    if (kv_cache_quant_mode.hasKvCacheQuant())
    {
        TLLM_CHECK_WITH_INFO(kv_cache_quant_mode.hasFp8KvCache(), "Only FP8 KV cache is supported for now");
        TORCH_CHECK(kv_scale_orig_quant.has_value());
        TORCH_CHECK(kv_scale_quant_orig.has_value());
        kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr<float>();
        kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr<float>();
        TLLM_CHECK(kv_scale_orig_quant_ptr != nullptr);
        TLLM_CHECK(kv_scale_quant_orig_ptr != nullptr);
    }

    if (input_dtype == torch::kFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            invokeMLARopeAppendPagedKVAssignQHelper<half, __nv_fp8_e4m3>(kv_cache_buffer, q, latent_cache,
                num_contexts, cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num,
                nope_size, rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
        else
        {
            invokeMLARopeAppendPagedKVAssignQHelper<half, half>(kv_cache_buffer, q, latent_cache, num_contexts,
                cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num, nope_size,
                rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
    }
    else if (input_dtype == torch::kFloat32)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            invokeMLARopeAppendPagedKVAssignQHelper<float, __nv_fp8_e4m3>(kv_cache_buffer, q, latent_cache,
                num_contexts, cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num,
                nope_size, rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
        else
        {
            invokeMLARopeAppendPagedKVAssignQHelper<float, float>(kv_cache_buffer, q, latent_cache, num_contexts,
                cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num, nope_size,
                rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
    }
    else if (input_dtype == torch::kBFloat16)
    {
        if (kv_cache_quant_mode.hasFp8KvCache())
        {
            invokeMLARopeAppendPagedKVAssignQHelper<__nv_bfloat16, __nv_fp8_e4m3>(kv_cache_buffer, q, latent_cache,
                num_contexts, cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num,
                nope_size, rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
        else
        {
            invokeMLARopeAppendPagedKVAssignQHelper<__nv_bfloat16, __nv_bfloat16>(kv_cache_buffer, q, latent_cache,
                num_contexts, cu_ctx_cached_kv_lens, cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num,
                nope_size, rope_size, lora_size, kv_scale_orig_quant_ptr);
        }
    }
}

void mergeChunkedAttentionForMLA(torch::Tensor& merged_attn, torch::Tensor const& temp_attn,
    torch::Tensor& merged_softmax_stats, torch::Tensor const& temp_softmax_stats, int64_t const num_requests,
    torch::Tensor const& cu_q_seq_lens, int64_t const max_q_seq_len, torch::Tensor const& merge_op,
    int64_t const num_heads, int64_t const head_size)
{
    TORCH_CHECK(merged_attn.numel() > 0);
    TORCH_CHECK(temp_attn.numel() > 0);
    TORCH_CHECK(merged_attn.scalar_type() == temp_attn.scalar_type());
    TORCH_CHECK(merged_attn.scalar_type() == torch::kFloat16 || merged_attn.scalar_type() == torch::kFloat32
        || merged_attn.scalar_type() == torch::kBFloat16);
    TORCH_CHECK(temp_softmax_stats.scalar_type() == merged_softmax_stats.scalar_type());
    TORCH_CHECK(merged_softmax_stats.scalar_type() == torch::kFloat32);

    if (merged_attn.scalar_type() == torch::kFloat16)
    {
        mergeChunkedAttentionForMLAHelper<half>(merged_attn, temp_attn, merged_softmax_stats, temp_softmax_stats,
            num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size);
    }
    else if (merged_attn.scalar_type() == torch::kFloat32)
    {
        mergeChunkedAttentionForMLAHelper<float>(merged_attn, temp_attn, merged_softmax_stats, temp_softmax_stats,
            num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size);
    }
    else if (merged_attn.scalar_type() == torch::kBFloat16)
    {
        mergeChunkedAttentionForMLAHelper<__nv_bfloat16>(merged_attn, temp_attn, merged_softmax_stats,
            temp_softmax_stats, num_requests, cu_q_seq_lens, max_q_seq_len, merge_op, num_heads, head_size);
    }
}

} // namespace torch_ext
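// The fragments below register each op with the `trtllm` TorchScript library, so they are
// callable from Python as torch.ops.trtllm.<op_name>(...); the TORCH_LIBRARY_IMPL blocks
// bind the CUDA implementations defined above to those schemas.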
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "load_paged_kv_cache_for_mla("
        "ScalarType out_dtype"
        ", int num_contexts"
        ", int num_ctx_cached_tokens"
        ", int max_ctx_cached_kv_len"
        ", Tensor cu_ctx_cached_kv_lens"
        ", Tensor kv_cache_block_offsets"
        ", Tensor host_kv_cache_block_offsets"
        ", Tensor host_kv_cache_pool_pointers"
        ", Tensor host_kv_cache_pool_mapping"
        ", Tensor? kv_scale_orig_quant"
        ", Tensor? kv_scale_quant_orig"
        ", int layer_idx"
        ", int lora_size"
        ", int rope_size"
        ", int tokens_per_block"
        ", int attention_window_size"
        ", int sink_token_length"
        ", int beam_width"
        ", int quant_mode"
        ") -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("load_paged_kv_cache_for_mla", &torch_ext::loadPagedKVCacheForMLA);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "load_chunked_kv_cache_for_mla("
        "ScalarType out_dtype"
        ", int num_contexts"
        ", int num_ctx_cached_tokens"
        ", Tensor cu_ctx_chunked_kv_lens"
        ", Tensor kv_cache_block_offsets"
        ", Tensor host_kv_cache_pool_pointers"
        ", Tensor host_kv_cache_pool_mapping"
        ", Tensor? kv_scale_orig_quant"
        ", Tensor? kv_scale_quant_orig"
        ", int layer_idx"
        ", int lora_size"
        ", int rope_size"
        ", int tokens_per_block"
        ", int chunked_size"
        ", int chunked_index"
        ", int attention_window_size"
        ", int sink_token_length"
        ", int beam_width"
        ", int quant_mode"
        ") -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("load_chunked_kv_cache_for_mla", &torch_ext::loadChunkedKVCacheForMLA);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "set_paged_kv_cache_for_mla("
        "Tensor output"
        ", Tensor k"
        ", Tensor v"
        ", Tensor k_pe"
        ", int num_requests"
        ", Tensor cu_seq_lens"
        ", int max_input_seq_len"
        ", int num_heads"
        ", int kv_dim"
        ", int rope_dim"
        ", int kv_cache_tokens_per_block"
        ") -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("set_paged_kv_cache_for_mla", &torch_ext::setPagedKVCacheForMLA);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "set_chunked_kv_cache_for_mla("
        "Tensor output"
        ", Tensor kv"
        ", Tensor k_pe"
        ", int num_requests"
        ", Tensor cu_seq_lens"
        ", int num_heads"
        ", int kv_dim"
        ", int rope_dim"
        ", int kv_cache_tokens_per_block"
        ", int max_seq_len"
        ") -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("set_chunked_kv_cache_for_mla", &torch_ext::setChunkedKVCacheForMLA);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mla_rope_append_paged_kv_assign_q("
        "Tensor q"
        ", Tensor latent_cache"
        ", int num_contexts"
        ", Tensor cu_ctx_cached_kv_lens"
        ", Tensor cu_seq_lens"
        ", int max_input_uncached_seq_len"
        ", Tensor cos_sin_cache"
        ", int head_num"
        ", int nope_size"
        ", int rope_size"
        ", int lora_size"
        ", Tensor kv_cache_block_offsets"
        ", Tensor host_kv_cache_block_offsets"
        ", Tensor host_kv_cache_pool_pointers"
        ", Tensor host_kv_cache_pool_mapping"
        ", Tensor? kv_scale_orig_quant"
        ", Tensor? kv_scale_quant_orig"
        ", int layer_idx"
        ", int tokens_per_block"
        ", int attention_window_size"
        ", int sink_token_length"
        ", int beam_width"
        ", int quant_mode"
        ") -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mla_rope_append_paged_kv_assign_q", &torch_ext::MLARopeAppendPagedKVAssignQ);
}

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "merge_chunked_attention_for_mla("
        "Tensor merged_attn"
        ", Tensor temp_attn"
        ", Tensor merged_softmax_stats"
        ", Tensor temp_softmax_stats"
        ", int num_requests"
        ", Tensor cu_q_seq_lens"
        ", int max_q_seq_len"
        ", Tensor merge_op"
        ", int num_heads"
        ", int head_size"
        ") -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("merge_chunked_attention_for_mla", &torch_ext::mergeChunkedAttentionForMLA);
}