Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 02:02:01 +08:00)
[https://nvbugs/5669671][fix] Support GuidedDecoder with sharded logits (#10698)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Parent: 9f741fb254
Commit: 7b8b9ccbaf
@@ -64,8 +64,8 @@ __device__ PackedT packedNegativeInfinity()
 } // namespace

 template <typename T, typename PackedT, int32_t kBitsPerThread>
-__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
-    T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    logitsBitmaskKernel(T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -123,29 +123,28 @@ void logitsBitmaskDispatchToBitsPerThread(
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
     int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * smCount, batchSize);
     int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
-    int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);

     dim3 const block(kThreadsPerBlock);

     if (numBitsPerThread <= 4 && kAlignment <= 4)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
-        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else if (numBitsPerThread <= 8 && kAlignment <= 8)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
-        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else if (numBitsPerThread <= 16 && kAlignment <= 16)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
-        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
-        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
 }
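For intuition, here is a minimal Python sketch of the dispatch heuristic above. The kThreadsPerBlock value (256) and SM count (108) are illustrative assumptions, not values taken from this diff, and the sketch ignores the kAlignment cap that the real dispatch also applies:

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def pick_bits_per_thread(vocab_size_padded: int, batch_size: int,
                         threads_per_block: int = 256, sm_count: int = 108) -> int:
    # Spread roughly (2048 / threads_per_block) * sm_count blocks across the
    # batch rows, then size each thread's strip of the vocabulary accordingly.
    num_blocks_per_row = ceil_div(2048 // threads_per_block * sm_count, batch_size)
    num_bits_per_thread = ceil_div(vocab_size_padded,
                                   threads_per_block * num_blocks_per_row)
    for bits in (4, 8, 16):
        if num_bits_per_thread <= bits:
            return bits
    return 32

print(pick_bits_per_thread(vocab_size_padded=131072, batch_size=8))  # -> 8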
@@ -185,7 +184,7 @@ template void invokeLogitsBitmask<__nv_bfloat16>(
 template <typename T, typename PackedT, int32_t kBitsPerThread>
 __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKernel(T* __restrict__ logits,
     uint32_t const* __restrict__ bitmask, int32_t const* __restrict__ tokenMask, int32_t const* __restrict__ d2t,
-    int32_t vocabSizePadded, int32_t bitmaskSize)
+    int32_t vocabSizePadded, int32_t bitmaskStride)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -199,7 +198,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
     int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
     T* logitsGmemPtr = logits + batchIdx * vocabSizePadded + blockOffset;

-    uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskSize + blockOffset / kBitsPerMaskElement;
+    uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskStride + blockOffset / kBitsPerMaskElement;
     int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
     T logitsReg[kAlignment];

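The switch from bitmaskSize to bitmaskStride matters exactly here: when the bitmask tensor is a column slice of a wider buffer, consecutive batch rows sit bitmaskStride words apart even though each row's shard is narrower. A small NumPy sketch of the same indexing (NumPy and the toy shapes are assumptions for illustration):

import numpy as np

batch, full_words, tp_size = 2, 8, 2
full = np.arange(batch * full_words, dtype=np.uint32).reshape(batch, full_words)
rank = 1
shard = full[:, rank * 4:(rank + 1) * 4]   # non-contiguous view, 4 words per row
flat, base, stride = full.ravel(), rank * 4, full_words  # stride == bitmaskStride
for b in range(batch):
    for off in range(4):
        # mirrors: bitmask + batchIdx * bitmaskStride + blockOffset / kBitsPerMaskElement
        assert flat[base + b * stride + off] == shard[b, off]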
@@ -224,7 +223,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
         for (int i = 0; i < kAlignment; i++)
         {
             int const d2tOffset = blockOffset + offset + i + d2t[blockOffset + offset + i];
-            bitmaskVal |= ((~bitmask[batchIdx * bitmaskSize + d2tOffset / kBitsPerMaskElement]
+            bitmaskVal |= ((~bitmask[batchIdx * bitmaskStride + d2tOffset / kBitsPerMaskElement]
                                >> (d2tOffset % kBitsPerMaskElement))
                               & 1)
                 << i;
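The d2t path reads the mask bit at the mapped target-vocab position rather than at the draft position. A pure-Python sketch of that bit extraction (the toy mask and mapping are assumptions; a cleared bit in the packed mask means the token is blocked):

def is_blocked(bitmask_row: list[int], d2t: list[int], j: int) -> bool:
    t = j + d2t[j]                          # d2tOffset: draft j -> target j + d2t[j]
    word = bitmask_row[t // 32]             # the batch row is already selected
    return ((~word >> (t % 32)) & 1) == 1   # a cleared bit means "blocked"

row = [0xFFFFFFFF, 0xFFFFFFFF]   # 64-token target vocab, everything allowed...
row[1] &= ~(1 << 8)              # ...except target token 40 (word 1, bit 8)
d2t = [10] * 32                  # toy mapping: draft j maps to target j + 10
print(is_blocked(row, d2t, 30))  # True: draft 30 maps to blocked target 40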
@@ -257,7 +256,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne

 template <typename T, typename PackedT>
 void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -270,63 +269,69 @@ void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* b
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 4>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else if (numBitsPerThread <= 8 && kAlignment <= 8)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 8>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else if (numBitsPerThread <= 16 && kAlignment <= 16)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 16>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 32>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
 }

 template <typename T>
 void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
-    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
+    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
 {
+    // bitmaskStride may not be equal to ceilDiv(vocabSizePadded, kBitsPerMaskElement) when:
+    // (1) d2t is present: the logits are "pruned" (e.g., EAGLE3), while the bitmask is complete and should be
+    // accessed according to d2t.
+    // (2) the logits are sharded along the vocabulary dimension, while the bitmask is complete and should be
+    // accessed according to the sharding.

     // Dispatch to PackedT
     if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float4>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float2>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, T>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
 }

 template void invokeContiguousLogitsBitmask<float>(float* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
 template void invokeContiguousLogitsBitmask<half>(half* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);

 #ifdef ENABLE_BF16
 template void invokeContiguousLogitsBitmask<__nv_bfloat16>(__nv_bfloat16* logits, uint32_t const* bitmask,
-    int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize,
+    int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride,
     cudaStream_t stream);
 #endif
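Case (2) in the new comment is easy to see from the PyTorch side: a column slice of a 2-D bitmask keeps stride(0) equal to the full row width while size(1) shrinks to the per-rank share. A short torch sketch (the shapes are illustrative):

import torch

vocab, tp_size, batch = 1024, 4, 2
full_words = vocab // 32                       # 32 == kBitsPerMaskElement
bitmask = torch.zeros(batch, full_words, dtype=torch.int32)
rank = 1
words_per_rank = full_words // tp_size
shard = bitmask[:, rank * words_per_rank:(rank + 1) * words_per_rank]
print(shard.shape)            # torch.Size([2, 8]): per-rank width
print(shard.stride(0))        # 32: the full row width, i.e. bitmaskStride
print(shard.is_contiguous())  # False: why bitmask.stride(0) is now passed through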
@@ -32,7 +32,7 @@ void invokeLogitsBitmask(

 template <typename T>
 void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
-    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);

 } // namespace kernels
@@ -23,6 +23,11 @@ TRTLLM_NAMESPACE_BEGIN
 namespace torch_ext
 {

+namespace
+{
+int32_t constexpr kBitsPerMaskElement = 32;
+} // namespace
+
 void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     at::optional<torch::Tensor> const& tokenMask = at::nullopt, at::optional<torch::Tensor> const& d2t = at::nullopt)
 {
@@ -34,12 +39,15 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     TORCH_CHECK(bitmask.size(0) == batchSize, "bitmask must have the same batch size as logits.");

     int32_t vocabSizePadded = logits.size(1);
-    int32_t bitmaskSize = bitmask.size(1);
+    if (!d2t.has_value())
+    {
+        TORCH_CHECK(bitmask.size(1) == tensorrt_llm::common::ceilDiv(vocabSizePadded, kBitsPerMaskElement),
+            "bitmask.size(1) must be equal to ceilDiv(vocab_size, 32).");
+    }
     TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
     TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
     TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor.");
     TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
-    TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
     TORCH_CHECK(bitmask.dim() == 2, "bitmask must be a 2D tensor.");
     TORCH_CHECK(bitmask.scalar_type() == torch::kUInt32 || bitmask.scalar_type() == torch::kInt32,
         "bitmask must have element type uint32 or int32.");
@@ -52,7 +60,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
         TORCH_CHECK(tokenMask->dim() == 1, "tokenMask must be a 1D tensor.");
         TORCH_CHECK(tokenMask->size(0) == batchSize, "tokenMask must have the same batch size as logits.");
         TORCH_CHECK(tokenMask->scalar_type() == torch::kInt32, "tokenMask must have element type int32.");
-        tokenMaskPtr = reinterpret_cast<int32_t const*>(tokenMask->data_ptr());
+        tokenMaskPtr = static_cast<int32_t const*>(tokenMask->data_ptr());
     }

     int32_t const* d2tPtr = nullptr;
@@ -63,7 +71,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
         TORCH_CHECK(d2t->dim() == 1, "d2t must be a 1D tensor.");
         TORCH_CHECK(d2t->size(0) == vocabSizePadded, "d2t must have the same vocab size as logits.");
         TORCH_CHECK(d2t->scalar_type() == torch::kInt32, "d2t must have element type int32.");
-        d2tPtr = reinterpret_cast<int32_t const*>(d2t->data_ptr());
+        d2tPtr = static_cast<int32_t const*>(d2t->data_ptr());
     }

     auto stream = at::cuda::getCurrentCUDAStream(logits.get_device()).stream();
@@ -72,23 +80,23 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     {
     case torch::kFloat32:
     {
-        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(reinterpret_cast<float*>(logits.data_ptr()),
-            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
-            bitmaskSize, stream);
+        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(static_cast<float*>(logits.data_ptr()),
+            static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
+            bitmask.stride(0), stream);
         break;
     }
     case torch::kFloat16:
     {
-        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(reinterpret_cast<__half*>(logits.data_ptr()),
-            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
-            bitmaskSize, stream);
+        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(static_cast<__half*>(logits.data_ptr()),
+            static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
+            bitmask.stride(0), stream);
         break;
     }
     case torch::kBFloat16:
     {
         tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__nv_bfloat16>(
-            reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()), reinterpret_cast<uint32_t const*>(bitmask.data_ptr()),
-            tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmaskSize, stream);
+            static_cast<__nv_bfloat16*>(logits.data_ptr()), static_cast<uint32_t const*>(bitmask.data_ptr()),
+            tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmask.stride(0), stream);
         break;
     }
     default: TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); break;
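Putting the binding together, the op can now accept a column-sliced (non-contiguous) bitmask whose row stride is forwarded to the kernel. A hedged usage sketch, assuming the trtllm extension is loaded and using an all-allowed toy mask (real masks come from the grammar matcher; shapes are illustrative):

import torch

batch, vocab, tp_size, rank = 4, 4096, 2, 0
logits = torch.randn(batch, vocab // tp_size, device="cuda", dtype=torch.float16)
bitmask = torch.full((batch, vocab // 32), -1, dtype=torch.int32, device="cuda")
words_per_rank = bitmask.size(1) // tp_size  # matches ceilDiv(logits.size(1), 32)
torch.ops.trtllm.logits_bitmask(
    logits, bitmask[:, rank * words_per_rank:(rank + 1) * words_per_rank])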
@@ -145,11 +145,13 @@ class GuidedDecoder:
                  guided_decoding_config: GuidedDecodingConfig,
                  max_num_sequences: int,
                  vocab_size_padded: int,
-                 max_num_draft_tokens: int = 0):
+                 max_num_draft_tokens: int = 0,
+                 rank: int = 0):
         self.guided_decoding_backend = guided_decoding_config.backend
         self.max_num_sequences = max_num_sequences
         self.vocab_size_padded = vocab_size_padded
         self.max_num_draft_tokens = max_num_draft_tokens
+        self.rank = rank

         if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
             self.grammar_matcher_factory = XGrammarMatcherFactory(
@@ -305,9 +307,26 @@ class GuidedDecoder:
         """
         if num_bitmask_tokens is None:
             num_bitmask_tokens = requests.num_bitmask_tokens

+        # In general, the logits passed to GuidedDecoder are complete in the vocabulary dimension.
+        # In some special cases (e.g., MTP), the logits are sharded in the vocabulary dimension.
+        vocab_size_padded = self.vocab_size_padded if d2t is None else d2t.size(
+            0)
+        assert vocab_size_padded % logits.size(1) == 0
+        tp_size = vocab_size_padded // logits.size(1)
+        assert self.bitmask_size % tp_size == 0
+        tp_rank = self.rank % tp_size
+        bitmask_start = tp_rank * self.bitmask_size // tp_size
+        bitmask_end = bitmask_start + self.bitmask_size // tp_size
+
+        if d2t is not None:
+            d2t_start = tp_rank * vocab_size_padded // tp_size
+            d2t_end = d2t_start + vocab_size_padded // tp_size
+            d2t = d2t[d2t_start:d2t_end]
+
         torch.ops.trtllm.logits_bitmask(
             logits[:num_bitmask_tokens],
-            self.bitmask[:num_bitmask_tokens],
+            self.bitmask[:num_bitmask_tokens, bitmask_start:bitmask_end],
             token_mask=self.token_mask[:num_bitmask_tokens],
             d2t=d2t)
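A worked example of the sharding arithmetic above, with assumed numbers:

vocab_size_padded = 131072           # full padded vocabulary
logits_width = 32768                 # logits.size(1) on this rank
bitmask_size = vocab_size_padded // 32        # 4096 mask words per row
tp_size = vocab_size_padded // logits_width   # 4
tp_rank = 6 % tp_size                         # rank 6 of a larger world -> 2
bitmask_start = tp_rank * bitmask_size // tp_size       # 2048
bitmask_end = bitmask_start + bitmask_size // tp_size   # 3072
# This rank masks its 32768 logits with mask words [2048, 3072) of each row.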
@@ -423,9 +442,13 @@ class CapturableGuidedDecoder(GuidedDecoder):
                  guided_decoding_config: GuidedDecodingConfig,
                  max_num_sequences: int,
                  vocab_size_padded: int,
-                 max_num_draft_tokens: int = 0):
-        super().__init__(guided_decoding_config, max_num_sequences,
-                         vocab_size_padded, max_num_draft_tokens)
+                 max_num_draft_tokens: int = 0,
+                 rank: int = 0):
+        super().__init__(guided_decoding_config=guided_decoding_config,
+                         max_num_sequences=max_num_sequences,
+                         vocab_size_padded=vocab_size_padded,
+                         max_num_draft_tokens=max_num_draft_tokens,
+                         rank=rank)
         # self.requests should be accessed by normal host code;
         # self.requests_hostfunc should be accessed by hostfunc (CUDA callback).
         self.requests_hostfunc: Optional[GuidedRequests] = None
@@ -505,7 +505,8 @@ def create_py_executor(
         kwargs = {
             "guided_decoding_config": guided_decoding_config,
             "max_num_sequences": max_batch_size,
-            "vocab_size_padded": model_engine.model.vocab_size_padded
+            "vocab_size_padded": model_engine.model.vocab_size_padded,
+            "rank": mapping.rank,
         }
         if spec_config is not None:
             kwargs[