[https://nvbugs/5669671][fix] Support GuidedDecoder with sharded logits (#10698)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2026-01-16 11:04:26 +08:00 committed by GitHub
parent 9f741fb254
commit 7b8b9ccbaf
5 changed files with 79 additions and 42 deletions

View File

@@ -64,8 +64,8 @@ __device__ PackedT packedNegativeInfinity()
} // namespace
template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
__global__ void __launch_bounds__(kThreadsPerBlock)
logitsBitmaskKernel(T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded)
{
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -123,29 +123,28 @@ void logitsBitmaskDispatchToBitsPerThread(
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * smCount, batchSize);
int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
dim3 const block(kThreadsPerBlock);
if (numBitsPerThread <= 4 && kAlignment <= 4)
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
}
else if (numBitsPerThread <= 8 && kAlignment <= 8)
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
}
else if (numBitsPerThread <= 16 && kAlignment <= 16)
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
}
else
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
}
}
@@ -185,7 +184,7 @@ template void invokeLogitsBitmask<__nv_bfloat16>(
template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKernel(T* __restrict__ logits,
uint32_t const* __restrict__ bitmask, int32_t const* __restrict__ tokenMask, int32_t const* __restrict__ d2t,
int32_t vocabSizePadded, int32_t bitmaskSize)
int32_t vocabSizePadded, int32_t bitmaskStride)
{
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -199,7 +198,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
T* logitsGmemPtr = logits + batchIdx * vocabSizePadded + blockOffset;
uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskSize + blockOffset / kBitsPerMaskElement;
uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskStride + blockOffset / kBitsPerMaskElement;
int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
T logitsReg[kAlignment];
@@ -224,7 +223,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
for (int i = 0; i < kAlignment; i++)
{
int const d2tOffset = blockOffset + offset + i + d2t[blockOffset + offset + i];
bitmaskVal |= ((~bitmask[batchIdx * bitmaskSize + d2tOffset / kBitsPerMaskElement]
bitmaskVal |= ((~bitmask[batchIdx * bitmaskStride + d2tOffset / kBitsPerMaskElement]
>> (d2tOffset % kBitsPerMaskElement))
& 1)
<< i;
@@ -257,7 +256,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
template <typename T, typename PackedT>
void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* bitmask, int32_t const* tokenMask,
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
{
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -270,63 +269,69 @@ void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* b
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
contiguousLogitsBitmaskKernel<T, PackedT, 4>
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
}
else if (numBitsPerThread <= 8 && kAlignment <= 8)
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
contiguousLogitsBitmaskKernel<T, PackedT, 8>
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
}
else if (numBitsPerThread <= 16 && kAlignment <= 16)
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
contiguousLogitsBitmaskKernel<T, PackedT, 16>
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
}
else
{
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
contiguousLogitsBitmaskKernel<T, PackedT, 32>
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
<<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
}
}
template <typename T>
void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
{
// bitmaskStride may not equal ceilDiv(vocabSizePadded, kBitsPerMaskElement) when:
// (1) d2t is present: the logits are "pruned" (e.g., EAGLE3), while the bitmask is complete and should be accessed
// according to d2t.
// (2) the logits are sharded along the vocabulary dimension, while the bitmask is complete and should be accessed
// according to the sharding.
// Dispatch to PackedT
if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
{
contiguousLogitsBitmaskDispatchToBitsPerThread<T, float4>(
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
}
else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
{
contiguousLogitsBitmaskDispatchToBitsPerThread<T, float2>(
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
}
else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
{
contiguousLogitsBitmaskDispatchToBitsPerThread<T, float>(
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
}
else
{
contiguousLogitsBitmaskDispatchToBitsPerThread<T, T>(
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
}
}
template void invokeContiguousLogitsBitmask<float>(float* logits, uint32_t const* bitmask, int32_t const* tokenMask,
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
template void invokeContiguousLogitsBitmask<half>(half* logits, uint32_t const* bitmask, int32_t const* tokenMask,
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeContiguousLogitsBitmask<__nv_bfloat16>(__nv_bfloat16* logits, uint32_t const* bitmask,
int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize,
int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride,
cudaStream_t stream);
#endif
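
The key change in this file is that the per-row bitmask offset is now computed from a separate bitmaskStride instead of from the logits width, so a vocabulary-sharded logits tensor can still index into the full-vocabulary bitmask. A minimal CPU reference of that indexing, in Python, is sketched below; the function name, the vocab_offset parameter, and the shapes are illustrative only and not part of this commit.

import torch

def apply_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor,
                            vocab_offset: int = 0) -> torch.Tensor:
    # logits:  [batch, local_vocab], possibly one shard of the full vocabulary.
    # bitmask: [batch, bitmask_size] int32/uint32 over the *full* vocabulary;
    #          bit i of word w corresponds to token w * 32 + i.
    # vocab_offset: index of the first full-vocabulary token covered by this shard.
    out = logits.clone()
    batch, local_vocab = logits.shape
    for b in range(batch):
        for v in range(local_vocab):
            tok = vocab_offset + v
            word = int(bitmask[b, tok // 32]) & 0xFFFFFFFF
            if not (word >> (tok % 32)) & 1:  # cleared bit means the token is disallowed
                out[b, v] = float("-inf")
    return out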

View File

@@ -32,7 +32,7 @@ void invokeLogitsBitmask(
template <typename T>
void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
} // namespace kernels

View File

@@ -23,6 +23,11 @@ TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
namespace
{
int32_t constexpr kBitsPerMaskElement = 32;
} // namespace
void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
at::optional<torch::Tensor> const& tokenMask = at::nullopt, at::optional<torch::Tensor> const& d2t = at::nullopt)
{
@@ -34,12 +39,15 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
TORCH_CHECK(bitmask.size(0) == batchSize, "bitmask must have the same batch size as logits.");
int32_t vocabSizePadded = logits.size(1);
int32_t bitmaskSize = bitmask.size(1);
if (!d2t.has_value())
{
TORCH_CHECK(bitmask.size(1) == tensorrt_llm::common::ceilDiv(vocabSizePadded, kBitsPerMaskElement),
"bitmask.size(1) must be equal to ceilDiv(vocab_size, 32).");
}
TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor.");
TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
TORCH_CHECK(bitmask.dim() == 2, "bitmask must be a 2D tensor.");
TORCH_CHECK(bitmask.scalar_type() == torch::kUInt32 || bitmask.scalar_type() == torch::kInt32,
"bitmask must have element type uint32 or int32.");
@@ -52,7 +60,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
TORCH_CHECK(tokenMask->dim() == 1, "tokenMask must be a 1D tensor.");
TORCH_CHECK(tokenMask->size(0) == batchSize, "tokenMask must have the same batch size as logits.");
TORCH_CHECK(tokenMask->scalar_type() == torch::kInt32, "tokenMask must have element type int32.");
tokenMaskPtr = reinterpret_cast<int32_t const*>(tokenMask->data_ptr());
tokenMaskPtr = static_cast<int32_t const*>(tokenMask->data_ptr());
}
int32_t const* d2tPtr = nullptr;
@@ -63,7 +71,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
TORCH_CHECK(d2t->dim() == 1, "d2t must be a 1D tensor.");
TORCH_CHECK(d2t->size(0) == vocabSizePadded, "d2t must have the same vocab size as logits.");
TORCH_CHECK(d2t->scalar_type() == torch::kInt32, "d2t must have element type int32.");
d2tPtr = reinterpret_cast<int32_t const*>(d2t->data_ptr());
d2tPtr = static_cast<int32_t const*>(d2t->data_ptr());
}
auto stream = at::cuda::getCurrentCUDAStream(logits.get_device()).stream();
@@ -72,23 +80,23 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
{
case torch::kFloat32:
{
tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(reinterpret_cast<float*>(logits.data_ptr()),
reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
bitmaskSize, stream);
tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(static_cast<float*>(logits.data_ptr()),
static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
bitmask.stride(0), stream);
break;
}
case torch::kFloat16:
{
tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(reinterpret_cast<__half*>(logits.data_ptr()),
reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
bitmaskSize, stream);
tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(static_cast<__half*>(logits.data_ptr()),
static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
bitmask.stride(0), stream);
break;
}
case torch::kBFloat16:
{
tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__nv_bfloat16>(
reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()), reinterpret_cast<uint32_t const*>(bitmask.data_ptr()),
tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmaskSize, stream);
static_cast<__nv_bfloat16*>(logits.data_ptr()), static_cast<uint32_t const*>(bitmask.data_ptr()),
tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmask.stride(0), stream);
break;
}
default: TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); break;

View File

@@ -145,11 +145,13 @@ class GuidedDecoder:
guided_decoding_config: GuidedDecodingConfig,
max_num_sequences: int,
vocab_size_padded: int,
max_num_draft_tokens: int = 0):
max_num_draft_tokens: int = 0,
rank: int = 0):
self.guided_decoding_backend = guided_decoding_config.backend
self.max_num_sequences = max_num_sequences
self.vocab_size_padded = vocab_size_padded
self.max_num_draft_tokens = max_num_draft_tokens
self.rank = rank
if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
self.grammar_matcher_factory = XGrammarMatcherFactory(
@@ -305,9 +307,26 @@ class GuidedDecoder:
"""
if num_bitmask_tokens is None:
num_bitmask_tokens = requests.num_bitmask_tokens
# In general, the logits passed to GuidedDecoder are complete in the vocabulary dimension.
# In some special cases (e.g., MTP), the logits are sharded in the vocabulary dimension.
vocab_size_padded = self.vocab_size_padded if d2t is None else d2t.size(
0)
assert vocab_size_padded % logits.size(1) == 0
tp_size = vocab_size_padded // logits.size(1)
assert self.bitmask_size % tp_size == 0
tp_rank = self.rank % tp_size
bitmask_start = tp_rank * self.bitmask_size // tp_size
bitmask_end = bitmask_start + self.bitmask_size // tp_size
if d2t is not None:
d2t_start = tp_rank * vocab_size_padded // tp_size
d2t_end = d2t_start + vocab_size_padded // tp_size
d2t = d2t[d2t_start:d2t_end]
torch.ops.trtllm.logits_bitmask(
logits[:num_bitmask_tokens],
self.bitmask[:num_bitmask_tokens],
self.bitmask[:num_bitmask_tokens, bitmask_start:bitmask_end],
token_mask=self.token_mask[:num_bitmask_tokens],
d2t=d2t)
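
For concreteness, here is the sharding arithmetic from the block above with illustrative numbers; the values are made up, only the formulas follow the source.

# Illustrative values; the formulas mirror the sharding logic above.
vocab_size_padded = 131072                              # full padded vocabulary
bitmask_size = (vocab_size_padded + 31) // 32           # 4096 mask words per row
local_logits_width = 65536                              # logits.size(1) on this rank
rank = 3

tp_size = vocab_size_padded // local_logits_width       # 2
tp_rank = rank % tp_size                                 # 1
bitmask_start = tp_rank * bitmask_size // tp_size        # 2048
bitmask_end = bitmask_start + bitmask_size // tp_size    # 4096
# The op then receives a [num_bitmask_tokens, 2048] view whose row stride is still 4096,
# so the kernel reads the correct half of each full-vocabulary bitmask row.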
@@ -423,9 +442,13 @@ class CapturableGuidedDecoder(GuidedDecoder):
guided_decoding_config: GuidedDecodingConfig,
max_num_sequences: int,
vocab_size_padded: int,
max_num_draft_tokens: int = 0):
super().__init__(guided_decoding_config, max_num_sequences,
vocab_size_padded, max_num_draft_tokens)
max_num_draft_tokens: int = 0,
rank: int = 0):
super().__init__(guided_decoding_config=guided_decoding_config,
max_num_sequences=max_num_sequences,
vocab_size_padded=vocab_size_padded,
max_num_draft_tokens=max_num_draft_tokens,
rank=rank)
# self.requests should be accessed by normal host code;
# self.requests_hostfunc should be accessed by hostfunc (CUDA callback).
self.requests_hostfunc: Optional[GuidedRequests] = None

View File

@@ -505,7 +505,8 @@ def create_py_executor(
kwargs = {
"guided_decoding_config": guided_decoding_config,
"max_num_sequences": max_batch_size,
"vocab_size_padded": model_engine.model.vocab_size_padded
"vocab_size_padded": model_engine.model.vocab_size_padded,
"rank": mapping.rank,
}
if spec_config is not None:
kwargs[
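
Putting it together, a hedged sketch of how the executor wires the rank into the decoder, using the kwargs shown above (the spec_config branch is truncated in this view, so the draft-token kwarg is omitted here; the variable names follow the snippet above):

kwargs = {
    "guided_decoding_config": guided_decoding_config,
    "max_num_sequences": max_batch_size,
    "vocab_size_padded": model_engine.model.vocab_size_padded,
    "rank": mapping.rank,
}
guided_decoder = GuidedDecoder(**kwargs)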