TensorRT-LLM/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp

/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/logitsBitmask.h"
#include "tensorrt_llm/thop/thUtils.h"
namespace torch_ext
{
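
// Applies a packed vocabulary bitmask to a batch of logits in place (disallowed
// tokens are suppressed by the kernel, typically by forcing their logits to -inf).
// The optional tokenMask and d2t inputs are described at their validation blocks below.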
void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
    at::optional<torch::Tensor> const& tokenMask = at::nullopt, at::optional<torch::Tensor> const& d2t = at::nullopt)
{
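    // Early exit: an empty batch has nothing to mask.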
    int32_t const batchSize = logits.size(0);
    if (batchSize == 0)
    {
        return;
    }
    TORCH_CHECK(bitmask.size(0) == batchSize, "bitmask must have the same batch size as logits.");
    int32_t vocabSizePadded = logits.size(1);
    int32_t bitmaskSize = bitmask.size(1);

    TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
    TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor.");
    TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
    TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
    TORCH_CHECK(bitmask.dim() == 2, "bitmask must be a 2D tensor.");
    TORCH_CHECK(bitmask.scalar_type() == torch::kUInt32 || bitmask.scalar_type() == torch::kInt32,
        "bitmask must have element type uint32 or int32.");
    int32_t const* tokenMaskPtr = nullptr;
    if (tokenMask.has_value())
    {
        TORCH_CHECK(tokenMask->is_cuda(), "tokenMask must be a CUDA tensor.");
        TORCH_CHECK(tokenMask->is_contiguous(), "tokenMask must be contiguous.");
        TORCH_CHECK(tokenMask->dim() == 1, "tokenMask must be a 1D tensor.");
        TORCH_CHECK(tokenMask->size(0) == batchSize, "tokenMask must have the same batch size as logits.");
        TORCH_CHECK(tokenMask->scalar_type() == torch::kInt32, "tokenMask must have element type int32.");
        tokenMaskPtr = reinterpret_cast<int32_t const*>(tokenMask->data_ptr());
    }
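
    // d2t (optional): an int32 mapping with one entry per logits column, used by
    // speculative decoding (draft-to-target vocabulary remapping). Fusing it into the
    // bitmask kernel avoids a separate remap pass over the logits.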
    int32_t const* d2tPtr = nullptr;
    if (d2t.has_value())
    {
        TORCH_CHECK(d2t->is_cuda(), "d2t must be a CUDA tensor.");
        TORCH_CHECK(d2t->is_contiguous(), "d2t must be contiguous.");
        TORCH_CHECK(d2t->dim() == 1, "d2t must be a 1D tensor.");
        TORCH_CHECK(d2t->size(0) == vocabSizePadded, "d2t must have the same vocab size as logits.");
        TORCH_CHECK(d2t->scalar_type() == torch::kInt32, "d2t must have element type int32.");
        d2tPtr = reinterpret_cast<int32_t const*>(d2t->data_ptr());
    }
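
    // Launch the kernel on the current CUDA stream, dispatching on the logits element type.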
    auto stream = at::cuda::getCurrentCUDAStream(logits.get_device()).stream();
    switch (logits.scalar_type())
    {
    case torch::kFloat32:
    {
        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(reinterpret_cast<float*>(logits.data_ptr()),
            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
            bitmaskSize, stream);
        break;
    }
    case torch::kFloat16:
    {
        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(reinterpret_cast<__half*>(logits.data_ptr()),
            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
            bitmaskSize, stream);
        break;
    }
    case torch::kBFloat16:
    {
        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__nv_bfloat16>(
            reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()), reinterpret_cast<uint32_t const*>(bitmask.data_ptr()),
            tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmaskSize, stream);
        break;
    }
    default: TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); break;
    }
}
} // namespace torch_ext
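
// Register the op schema with the trtllm library; the Tensor(a!) annotation marks
// logits as mutated in place.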
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("logits_bitmask(Tensor(a!) logits, Tensor bitmask, Tensor? token_mask=None, Tensor? d2t=None) -> ()");
}
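
// Bind the CUDA implementation of the op.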
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("logits_bitmask", &torch_ext::logitsBitmask);
}
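
// Usage sketch (illustrative only; the shapes and the 32-tokens-per-bitmask-word
// packing below are assumptions about the kernel layout, not guarantees made by this file):
//   auto logits  = torch::randn({8, 32000}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   auto bitmask = torch::zeros({8, (32000 + 31) / 32}, torch::dtype(torch::kInt32).device(torch::kCUDA));
//   torch_ext::logitsBitmask(logits, bitmask);
// From Python, the registered op is callable as torch.ops.trtllm.logits_bitmask(logits, bitmask).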