From 7b8b9ccbaf35fffe3d24fc4227c25b8d9633fda1 Mon Sep 17 00:00:00 2001
From: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Date: Fri, 16 Jan 2026 11:04:26 +0800
Subject: [PATCH] [https://nvbugs/5669671][fix] Support GuidedDecoder with sharded logits (#10698)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
---
 cpp/tensorrt_llm/kernels/logitsBitmask.cu    | 51 ++++++++++---------
 cpp/tensorrt_llm/kernels/logitsBitmask.h     |  2 +-
 cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp    | 32 +++++++-----
 .../_torch/pyexecutor/guided_decoder.py      | 33 ++++++++++--
 .../_torch/pyexecutor/py_executor_creator.py |  3 +-
 5 files changed, 79 insertions(+), 42 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/logitsBitmask.cu b/cpp/tensorrt_llm/kernels/logitsBitmask.cu
index ac66967e0f..0bb1ddd6be 100644
--- a/cpp/tensorrt_llm/kernels/logitsBitmask.cu
+++ b/cpp/tensorrt_llm/kernels/logitsBitmask.cu
@@ -64,8 +64,8 @@ __device__ PackedT packedNegativeInfinity()
 } // namespace
 
 template <typename T, typename PackedT, int32_t kBitsPerThread>
-__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
-    T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    logitsBitmaskKernel(T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -123,29 +123,28 @@ void logitsBitmaskDispatchToBitsPerThread(
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
     int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * smCount, batchSize);
     int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
-    int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
 
     dim3 const block(kThreadsPerBlock);
 
     if (numBitsPerThread <= 4 && kAlignment <= 4)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
-        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else if (numBitsPerThread <= 8 && kAlignment <= 8)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
-        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else if (numBitsPerThread <= 16 && kAlignment <= 16)
    {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
-        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
     else
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
-        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded);
     }
 }
 
@@ -185,7 +184,7 @@ template void invokeLogitsBitmask<__nv_bfloat16>(
 template <typename T, typename PackedT, int32_t kBitsPerThread>
 __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKernel(T* __restrict__ logits,
     uint32_t const* __restrict__ bitmask, int32_t const* __restrict__ tokenMask, int32_t const* __restrict__ d2t,
-    int32_t vocabSizePadded, int32_t bitmaskSize)
+    int32_t vocabSizePadded, int32_t bitmaskStride)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
@@ -199,7 +198,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
     int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
 
     T* logitsGmemPtr = logits + batchIdx * vocabSizePadded + blockOffset;
-    uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskSize + blockOffset / kBitsPerMaskElement;
+    uint32_t const* bitmaskGmemPtr = bitmask + batchIdx * bitmaskStride + blockOffset / kBitsPerMaskElement;
 
     int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
     T logitsReg[kAlignment];
@@ -224,7 +223,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
             for (int i = 0; i < kAlignment; i++)
             {
                 int const d2tOffset = blockOffset + offset + i + d2t[blockOffset + offset + i];
-                bitmaskVal |= ((~bitmask[batchIdx * bitmaskSize + d2tOffset / kBitsPerMaskElement]
+                bitmaskVal |= ((~bitmask[batchIdx * bitmaskStride + d2tOffset / kBitsPerMaskElement]
                                    >> (d2tOffset % kBitsPerMaskElement))
                     & 1)
                     << i;
@@ -257,7 +256,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) contiguousLogitsBitmaskKerne
 
 template <typename T, typename PackedT>
 void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -270,63 +269,69 @@ void contiguousLogitsBitmaskDispatchToBitsPerThread(T* logits, uint32_t const* b
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 4>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else if (numBitsPerThread <= 8 && kAlignment <= 8)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 8>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else if (numBitsPerThread <= 16 && kAlignment <= 16)
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 16>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
     else
     {
         dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
         contiguousLogitsBitmaskKernel<T, PackedT, 32>
-            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskSize);
+            <<<grid, block, 0, stream>>>(logits, bitmask, tokenMask, d2t, vocabSizePadded, bitmaskStride);
     }
 }
 
 template <typename T>
 void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
-    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream)
+    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream)
 {
+    // bitmaskStride may not equal ceilDiv(vocabSizePadded, kBitsPerMaskElement) when:
+    // (1) d2t is present: the logits are "pruned" (e.g., EAGLE3), while the bitmask is complete and should be
+    // accessed according to d2t.
+    // (2) the logits are sharded along the vocabulary dimension, while the bitmask is complete and should be
+    // accessed according to the sharding.
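+    // As an illustrative example (numbers not taken from this change): with a padded vocabulary of 128000
+    // sharded across 4 ranks, each rank's logits cover 32000 entries, so the kernel sees vocabSizePadded = 32000
+    // while bitmaskStride remains ceilDiv(128000, 32) = 4000, the full width of each bitmask row.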
+    // Dispatch to PackedT
     if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float4>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float2>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, float>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
     else
     {
         contiguousLogitsBitmaskDispatchToBitsPerThread<T, T>(
-            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskSize, stream);
+            logits, bitmask, tokenMask, d2t, batchSize, vocabSizePadded, bitmaskStride, stream);
     }
 }
 
 template void invokeContiguousLogitsBitmask<float>(float* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
 template void invokeContiguousLogitsBitmask<half>(half* logits, uint32_t const* bitmask, int32_t const* tokenMask,
-    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
 #ifdef ENABLE_BF16
 template void invokeContiguousLogitsBitmask<__nv_bfloat16>(__nv_bfloat16* logits, uint32_t const* bitmask,
-    int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize,
+    int32_t const* tokenMask, int32_t const* d2t, int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride,
     cudaStream_t stream);
 #endif
diff --git a/cpp/tensorrt_llm/kernels/logitsBitmask.h b/cpp/tensorrt_llm/kernels/logitsBitmask.h
index e2e6cb28cd..e3fb6c08fb 100644
--- a/cpp/tensorrt_llm/kernels/logitsBitmask.h
+++ b/cpp/tensorrt_llm/kernels/logitsBitmask.h
@@ -32,7 +32,7 @@ void invokeLogitsBitmask(
 
 template <typename T>
 void invokeContiguousLogitsBitmask(T* logits, uint32_t const* bitmask, int32_t const* tokenMask, int32_t const* d2t,
-    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskSize, cudaStream_t stream);
+    int32_t batchSize, int32_t vocabSizePadded, int32_t bitmaskStride, cudaStream_t stream);
 
 } // namespace kernels
 
diff --git a/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp b/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
index 2f6eddd5ca..6c3e83fd85 100644
--- a/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
+++ b/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
@@ -23,6 +23,11 @@ TRTLLM_NAMESPACE_BEGIN
 namespace torch_ext
 {
 
+namespace
+{
+int32_t constexpr kBitsPerMaskElement = 32;
+} // namespace
+
 void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     at::optional<torch::Tensor> const& tokenMask = at::nullopt, at::optional<torch::Tensor> const& d2t = at::nullopt)
 {
@@ -34,12 +39,15 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     TORCH_CHECK(bitmask.size(0) == batchSize, "bitmask must have the same batch size as logits.");
 
     int32_t vocabSizePadded = logits.size(1);
-    int32_t bitmaskSize = bitmask.size(1);
+    if (!d2t.has_value())
+    {
+        TORCH_CHECK(bitmask.size(1) == tensorrt_llm::common::ceilDiv(vocabSizePadded, kBitsPerMaskElement),
+            "bitmask.size(1) must be equal to ceilDiv(vocab_size, 32).");
+    }
 
     TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
     TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
     TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor.");
     TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
-    TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
     TORCH_CHECK(bitmask.dim() == 2, "bitmask must be a 2D tensor.");
     TORCH_CHECK(bitmask.scalar_type() == torch::kUInt32 || bitmask.scalar_type() == torch::kInt32,
         "bitmask must have element type uint32 or int32.");
@@ -52,7 +60,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
         TORCH_CHECK(tokenMask->dim() == 1, "tokenMask must be a 1D tensor.");
         TORCH_CHECK(tokenMask->size(0) == batchSize, "tokenMask must have the same batch size as logits.");
         TORCH_CHECK(tokenMask->scalar_type() == torch::kInt32, "tokenMask must have element type int32.");
-        tokenMaskPtr = reinterpret_cast<int32_t const*>(tokenMask->data_ptr());
+        tokenMaskPtr = static_cast<int32_t const*>(tokenMask->data_ptr());
     }
 
     int32_t const* d2tPtr = nullptr;
@@ -63,7 +71,7 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
         TORCH_CHECK(d2t->dim() == 1, "d2t must be a 1D tensor.");
         TORCH_CHECK(d2t->size(0) == vocabSizePadded, "d2t must have the same vocab size as logits.");
         TORCH_CHECK(d2t->scalar_type() == torch::kInt32, "d2t must have element type int32.");
-        d2tPtr = reinterpret_cast<int32_t const*>(d2t->data_ptr());
+        d2tPtr = static_cast<int32_t const*>(d2t->data_ptr());
     }
 
     auto stream = at::cuda::getCurrentCUDAStream(logits.get_device()).stream();
@@ -72,23 +80,23 @@ void logitsBitmask(torch::Tensor const& logits, torch::Tensor const& bitmask,
     {
     case torch::kFloat32:
     {
-        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(reinterpret_cast<float*>(logits.data_ptr()),
-            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
-            bitmaskSize, stream);
+        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<float>(static_cast<float*>(logits.data_ptr()),
+            static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
+            bitmask.stride(0), stream);
         break;
     }
     case torch::kFloat16:
     {
-        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(reinterpret_cast<__half*>(logits.data_ptr()),
-            reinterpret_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
-            bitmaskSize, stream);
+        tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__half>(static_cast<__half*>(logits.data_ptr()),
+            static_cast<uint32_t const*>(bitmask.data_ptr()), tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded,
+            bitmask.stride(0), stream);
         break;
     }
     case torch::kBFloat16:
     {
         tensorrt_llm::kernels::invokeContiguousLogitsBitmask<__nv_bfloat16>(
-            reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()), reinterpret_cast<uint32_t const*>(bitmask.data_ptr()),
-            tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmaskSize, stream);
+            static_cast<__nv_bfloat16*>(logits.data_ptr()), static_cast<uint32_t const*>(bitmask.data_ptr()),
+            tokenMaskPtr, d2tPtr, batchSize, vocabSizePadded, bitmask.stride(0), stream);
         break;
     }
     default: TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); break;
diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index 0d40951604..01386e55e5 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -145,11 +145,13 @@ class GuidedDecoder:
                  guided_decoding_config: GuidedDecodingConfig,
                  max_num_sequences: int,
                  vocab_size_padded: int,
-                 max_num_draft_tokens: int = 0):
+                 max_num_draft_tokens: int = 0,
+                 rank: int = 0):
         self.guided_decoding_backend = guided_decoding_config.backend
         self.max_num_sequences = max_num_sequences
         self.vocab_size_padded = vocab_size_padded
         self.max_num_draft_tokens = max_num_draft_tokens
+        self.rank = rank
 
         if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
             self.grammar_matcher_factory = XGrammarMatcherFactory(
@@ -305,9 +307,26 @@ class GuidedDecoder:
         """
         if num_bitmask_tokens is None:
             num_bitmask_tokens = requests.num_bitmask_tokens
+
+        # In general, the logits passed to GuidedDecoder are complete in the vocabulary dimension.
+        # In some special cases (e.g., MTP), the logits are sharded in the vocabulary dimension.
+        vocab_size_padded = self.vocab_size_padded if d2t is None else d2t.size(
+            0)
+        assert vocab_size_padded % logits.size(1) == 0
+        tp_size = vocab_size_padded // logits.size(1)
+        assert self.bitmask_size % tp_size == 0
+        tp_rank = self.rank % tp_size
+        bitmask_start = tp_rank * self.bitmask_size // tp_size
+        bitmask_end = bitmask_start + self.bitmask_size // tp_size
+
+        if d2t is not None:
+            d2t_start = tp_rank * vocab_size_padded // tp_size
+            d2t_end = d2t_start + vocab_size_padded // tp_size
+            d2t = d2t[d2t_start:d2t_end]
+
         torch.ops.trtllm.logits_bitmask(
             logits[:num_bitmask_tokens],
-            self.bitmask[:num_bitmask_tokens],
+            self.bitmask[:num_bitmask_tokens, bitmask_start:bitmask_end],
             token_mask=self.token_mask[:num_bitmask_tokens],
             d2t=d2t)
 
@@ -423,9 +442,13 @@ class CapturableGuidedDecoder(GuidedDecoder):
                  guided_decoding_config: GuidedDecodingConfig,
                  max_num_sequences: int,
                  vocab_size_padded: int,
-                 max_num_draft_tokens: int = 0):
-        super().__init__(guided_decoding_config, max_num_sequences,
-                         vocab_size_padded, max_num_draft_tokens)
+                 max_num_draft_tokens: int = 0,
+                 rank: int = 0):
+        super().__init__(guided_decoding_config=guided_decoding_config,
+                         max_num_sequences=max_num_sequences,
+                         vocab_size_padded=vocab_size_padded,
+                         max_num_draft_tokens=max_num_draft_tokens,
+                         rank=rank)
         # self.requests should be accessed by normal host code;
         # self.requests_hostfunc should be accessed by hostfunc (CUDA callback).
         self.requests_hostfunc: Optional[GuidedRequests] = None
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 25f391d813..46960dabe7 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -505,7 +505,8 @@ def create_py_executor(
         kwargs = {
             "guided_decoding_config": guided_decoding_config,
             "max_num_sequences": max_batch_size,
-            "vocab_size_padded": model_engine.model.vocab_size_padded
+            "vocab_size_padded": model_engine.model.vocab_size_padded,
+            "rank": mapping.rank,
        }
         if spec_config is not None:
             kwargs[
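
The following is a minimal standalone sketch of the rank-local slicing arithmetic that GuidedDecoder.execute
performs above, and of why the op can now accept a non-contiguous bitmask view. The shapes, the rank value, and
the bitmask_size variable are illustrative assumptions, not values taken from this patch.

import torch

# Hypothetical setup: a full padded vocabulary of 128 entries, sharded across tp_size ranks.
vocab_size_padded = 128
bitmask_size = (vocab_size_padded + 31) // 32  # ceilDiv(vocab_size, 32) uint32 words per row
rank = 1

# This rank's logits cover only its vocabulary shard (here, 2 tokens x 32 logits).
logits = torch.zeros(2, 32)

# Mirror of the slicing math in GuidedDecoder.execute:
assert vocab_size_padded % logits.size(1) == 0
tp_size = vocab_size_padded // logits.size(1)  # 4 in this example
assert bitmask_size % tp_size == 0
tp_rank = rank % tp_size
bitmask_start = tp_rank * bitmask_size // tp_size
bitmask_end = bitmask_start + bitmask_size // tp_size

# -1 sets all 32 bits of each int32 mask word (i.e., no tokens masked).
bitmask = torch.full((2, bitmask_size), -1, dtype=torch.int32)
shard = bitmask[:, bitmask_start:bitmask_end]

# The column slice is non-contiguous: its row stride is still the full row width.
# This is why logitsBitmaskOp.cpp drops the is_contiguous() check on bitmask and
# forwards bitmask.stride(0) to the kernel as bitmaskStride.
assert shard.stride(0) == bitmask_size

In the unsharded case tp_size is 1, the slice covers the whole row, and stride(0) degenerates to
ceilDiv(vocab_size, 32), so the existing single-rank behavior is unchanged.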