/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include "tensorrt_llm/kernels/IndexerTopK.h"

// #include
// #include
// #include
// #include
// #include
// #include
// #include

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

namespace
{

//! Narrows a 64-bit size/stride/count to int32_t for the kernel launchers.
//! Raises a Torch error instead of silently truncating when the value does not round-trip.
//! \param value The 64-bit quantity to narrow.
//! \param what  Human-readable name of the quantity, used in the error message.
int32_t checkedInt32(int64_t value, char const* what)
{
    auto const narrowed = static_cast<int32_t>(value);
    TORCH_CHECK(static_cast<int64_t>(narrowed) == value, what, " (", value, ") does not fit in a 32-bit integer");
    return narrowed;
}

} // namespace

//! Validates the decode-phase inputs and forwards them to tk::invokeIndexerTopKDecode on the current CUDA
//! stream of the logits device. The result is written in place into \p indices; nothing is returned.
//!
//! \param logits     2D CUDA tensor; may be strided, but both strides must be non-negative.
//! \param seq_lens   1D contiguous CUDA tensor with logits.size(0) / next_n entries.
//! \param indices    2D contiguous CUDA tensor with logits.size(0) rows and at least index_topk columns.
//! \param next_n     Positive factor relating seq_lens entries to logits rows (rows == seq_lens.size(0) * next_n).
//! \param index_topk Positive number of entries requested per row.
//!
//! Raises (via TORCH_CHECK) when any of the above constraints is violated or when a size/stride does not
//! fit in int32_t.
void indexer_topk_decode(
    th::Tensor const& logits, th::Tensor const& seq_lens, th::Tensor const& indices, int64_t next_n, int64_t index_topk)
{
    TORCH_CHECK(logits.is_cuda() && seq_lens.is_cuda() && indices.is_cuda(),
        "logits, seq_lens, and indices must be CUDA tensors");
    TORCH_CHECK(logits.get_device() == seq_lens.get_device() && logits.get_device() == indices.get_device(),
        "logits, seq_lens, and indices must be on the same device");
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
    TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D Tensor");
    TORCH_CHECK(indices.dim() == 2, "indices must be a 2D Tensor");
    // Checked before next_n is used as a divisor below.
    TORCH_CHECK(next_n > 0, "next_n must be greater than 0");
    TORCH_CHECK(index_topk > 0, "index_topk must be greater than 0");

    auto const numRows64 = logits.size(0);
    auto const numColumns64 = logits.size(1);

    // Same condition as seq_lens.size(0) * next_n == numRows64, phrased with division so that a huge
    // next_n cannot overflow the signed multiplication.
    TORCH_CHECK(numRows64 % next_n == 0 && numRows64 / next_n == seq_lens.size(0),
        "seq_lens length multiplied by next_n must equal logits.size(0)");
    TORCH_CHECK(indices.size(0) == numRows64, "indices first dimension must match logits.size(0)");
    TORCH_CHECK(indices.size(1) >= index_topk, "indices second dimension must be at least index_topk");
    TORCH_CHECK(seq_lens.is_contiguous(), "seq_lens must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");

    int32_t const num_rows = checkedInt32(numRows64, "logits.size(0)");
    int32_t const num_columns = checkedInt32(numColumns64, "logits.size(1)");
    int32_t const logits_stride_0 = checkedInt32(logits.stride(0), "logits.stride(0)");
    int32_t const logits_stride_1 = checkedInt32(logits.stride(1), "logits.stride(1)");
    int32_t const nextN = checkedInt32(next_n, "next_n");
    int32_t const topK = checkedInt32(index_topk, "index_topk");
    TORCH_CHECK(logits_stride_0 >= 0, "logits_stride_0 must be greater than or equal to 0");
    TORCH_CHECK(logits_stride_1 >= 0, "logits_stride_1 must be greater than or equal to 0");

    if (num_rows == 0)
    {
        // No rows means no output to produce; skip the launch entirely.
        return;
    }

    constexpr int32_t splitWorkThreshold = 200 * 1000;
    constexpr auto multipleBlocksPerRowConfig = 10;

    auto const int32Options = th::TensorOptions().dtype(th::kInt32).device(logits.device());
    auto const floatOptions = th::TensorOptions().dtype(th::kFloat32).device(logits.device());

    // Scratch buffers stay empty unless the row width reaches the split threshold, in which case they get
    // multipleBlocksPerRowConfig * index_topk slots per row.
    // NOTE(review): presumably consumed by a multi-block-per-row path inside the kernel launcher — confirm
    // in IndexerTopK.h.
    th::Tensor aux_indices = th::empty({0}, int32Options);
    th::Tensor aux_logits = th::empty({0}, floatOptions);
    if (num_columns >= splitWorkThreshold)
    {
        aux_indices = th::empty({num_rows, multipleBlocksPerRowConfig, index_topk}, int32Options);
        aux_logits = th::empty({num_rows, multipleBlocksPerRowConfig, index_topk}, floatOptions);
    }

    auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());

    // NOTE(review): pointer element types (float logits, int32 length/index tensors) are assumed to match
    // the launcher declaration in IndexerTopK.h — confirm. data_ptr<T>() raises on a dtype mismatch.
    tk::invokeIndexerTopKDecode(logits.data_ptr<float>(), seq_lens.data_ptr<int32_t>(), indices.data_ptr<int32_t>(),
        aux_logits.data_ptr<float>(), aux_indices.data_ptr<int32_t>(), splitWorkThreshold, num_rows, num_columns,
        logits_stride_0, logits_stride_1, nextN, topK, stream);
}

//! Validates the prefill-phase inputs and forwards them to tk::invokeIndexerTopKPrefill on the current CUDA
//! stream of the logits device. The result is written in place into \p indices; nothing is returned.
//!
//! \param logits     2D CUDA tensor; may be strided, but both strides must be non-negative.
//! \param row_starts 1D contiguous CUDA tensor with one entry per logits row.
//! \param row_ends   1D contiguous CUDA tensor with one entry per logits row.
//! \param indices    2D contiguous CUDA tensor with logits.size(0) rows and at least index_topk columns.
//! \param index_topk Positive number of entries requested per row.
//!
//! Raises (via TORCH_CHECK) when any of the above constraints is violated or when a size/stride does not
//! fit in int32_t.
void indexer_topk_prefill(th::Tensor const& logits, th::Tensor const& row_starts, th::Tensor const& row_ends,
    th::Tensor const& indices, int64_t index_topk)
{
    TORCH_CHECK(logits.is_cuda() && row_starts.is_cuda() && row_ends.is_cuda() && indices.is_cuda(),
        "logits, row_starts, row_ends, and indices must be CUDA tensors");
    TORCH_CHECK(logits.get_device() == row_starts.get_device() && logits.get_device() == row_ends.get_device()
            && logits.get_device() == indices.get_device(),
        "logits, row_starts, row_ends, and indices must be on the same device");
    TORCH_CHECK(indices.dim() == 2, "indices must be a 2D Tensor");
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
    TORCH_CHECK(row_starts.dim() == 1, "row_starts must be a 1D Tensor");
    TORCH_CHECK(row_ends.dim() == 1, "row_ends must be a 1D Tensor");
    TORCH_CHECK(index_topk > 0, "index_topk must be greater than 0");

    auto const numRows64 = logits.size(0);
    auto const numColumns64 = logits.size(1);

    TORCH_CHECK(row_starts.size(0) == numRows64 && row_ends.size(0) == numRows64,
        "row_starts/row_ends must have one entry per row in logits");
    // The output buffer is handed to the kernel as a raw pointer, so enforce the same shape and layout
    // requirements that indexer_topk_decode enforces.
    TORCH_CHECK(indices.size(0) == numRows64, "indices first dimension must match logits.size(0)");
    TORCH_CHECK(indices.size(1) >= index_topk, "indices second dimension must be at least index_topk");
    TORCH_CHECK(row_starts.is_contiguous(), "row_starts must be contiguous");
    TORCH_CHECK(row_ends.is_contiguous(), "row_ends must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");

    int32_t const num_rows = checkedInt32(numRows64, "logits.size(0)");
    int32_t const num_columns = checkedInt32(numColumns64, "logits.size(1)");
    int32_t const logits_stride_0 = checkedInt32(logits.stride(0), "logits.stride(0)");
    int32_t const logits_stride_1 = checkedInt32(logits.stride(1), "logits.stride(1)");
    int32_t const topK = checkedInt32(index_topk, "index_topk");
    TORCH_CHECK(logits_stride_0 >= 0, "logits_stride_0 must be greater than or equal to 0");
    TORCH_CHECK(logits_stride_1 >= 0, "logits_stride_1 must be greater than or equal to 0");

    if (num_rows == 0)
    {
        // No rows means no output to produce; skip the launch entirely.
        return;
    }

    auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());

    // NOTE(review): pointer element types (float logits, int32 bound/index tensors) are assumed to match
    // the launcher declaration in IndexerTopK.h — confirm. data_ptr<T>() raises on a dtype mismatch.
    tk::invokeIndexerTopKPrefill(logits.data_ptr<float>(), row_starts.data_ptr<int32_t>(),
        row_ends.data_ptr<int32_t>(), indices.data_ptr<int32_t>(), num_rows, num_columns, logits_stride_0,
        logits_stride_1, topK, stream);
}

} // end namespace torch_ext

// Operator schema for the decode-phase op in the trtllm library; it returns nothing.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "indexer_topk_decode(Tensor logits, Tensor seq_lens, Tensor indices, int next_n, int index_topk=2048) -> "
        "()");
}

// Bind the decode schema to its implementation for the CUDA dispatch key.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("indexer_topk_decode", &torch_ext::indexer_topk_decode);
}

// Operator schema for the prefill-phase op in the trtllm library; it returns nothing.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "indexer_topk_prefill(Tensor logits, Tensor row_starts, Tensor row_ends, Tensor indices, int "
        "index_topk=2048) -> ()");
}

// Bind the prefill schema to its implementation for the CUDA dispatch key.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("indexer_topk_prefill", &torch_ext::indexer_topk_prefill);
}