/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/medusaModule.h"

#include <algorithm>
#include <cstring>
#include <numeric>
#include <sstream>
#include <stack>
#include <unordered_map>

namespace tensorrt_llm::runtime
{

//! \brief Converts a flat list of Medusa choices (each choice is the path of sibling indices
//! from the root to a node) into the dense tensors the Medusa decoder consumes.
//!
//! \param choices                Medusa tree description; choices[i] is the list of per-level
//!                               sibling indices identifying node i.
//! \param topKs                  [out] per-head TopK — the max sibling fan-out found at each depth.
//! \param generationInputLengths [out] filled with tokensPerStep() for every slot (fixed length).
//! \param positionOffsets        [out] depth of each node in linear order (root = 0).
//! \param treeIds                [out] flattened token index of each non-root node.
//! \param paths                  [out] root-to-leaf paths, written by computePathsAndMask.
//! \param packedMask             [out] packed ancestor attention mask, written by computePathsAndMask.
//! \param totalPaths             [out] number of leaves (= number of paths).
void MedusaModule::initMedusaTensorsFromChoices(MedusaChoices const& choices, std::vector<SizeType32>& topKs,
    TensorPtr& generationInputLengths, TensorPtr& positionOffsets, TensorPtr& treeIds, TensorPtr& paths,
    TensorPtr& packedMask, SizeType32& totalPaths) const noexcept
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const numChoices = static_cast<SizeType32>(choices.size());
    std::vector<SizeType32> sortedIndices(numChoices);
    std::iota(sortedIndices.begin(), sortedIndices.end(), 0);
    std::vector<Prefix> prefixes(numChoices);
    // Fill prefixes for all nodes
    for (SizeType32 ci = 0; ci < numChoices; ++ci)
    {
        auto const& choice = choices[ci];
        prefixes[ci] = computePrefix(choice, choice.size());
    }

    // Sort indices based on depth and prefix
    // Prefix has property that on the same depth, nodes that are more left in the tree have smaller prefix
    std::sort(sortedIndices.begin(), sortedIndices.end(),
        [&prefixes, &choices](SizeType32 const& a, SizeType32 const& b)
        {
            auto const aSize = choices[a].size();
            auto const bSize = choices[b].size();
            return aSize < bSize || (aSize == bSize && prefixes[a] < prefixes[b]);
        });

    dumpChoices(choices, sortedIndices);

    topKs.resize(medusaHeads(), 0);

    auto generationInputLengthsPtr = bufferCast<SizeType32>(*generationInputLengths);
    auto positionOffsetsPtr = bufferCast<SizeType32>(*positionOffsets);
    auto treeIdsPtr = bufferCast<SizeType32>(*treeIds);

    // Fixed sequence length currently.
    std::fill(
        generationInputLengthsPtr, generationInputLengthsPtr + generationInputLengths->getSize(), tokensPerStep());
    std::fill(positionOffsetsPtr, positionOffsetsPtr + positionOffsets->getSize(), -1);
    std::fill(treeIdsPtr, treeIdsPtr + treeIds->getSize(), -1);

    // +1 for root
    std::vector<MedusaTreeNode> tree(choices.size() + 1);
    auto& rootNode = tree[0];
    rootNode.depth = 0;
    // That's how we recognize root
    rootNode.nodeId = -1;

    // Start depth from 1 because we count root node implicitly
    SizeType32 depth = 1;
    // Running max TopK, reset at every new level
    SizeType32 maxTopK = 0;
    // Global node in tree idx is sum of TopKs at previous levels
    SizeType32 globalNodeInTreeIdx = 0;
    // Hash table to map prefix to linear idx at previous level
    std::unordered_map<Prefix, SizeType32> prevPrefixToLinearIdxMap;
    // Hash table to map prefix to linear idx at current level
    std::unordered_map<Prefix, SizeType32> curPrefixToLinearIdxMap;
    // Add root node
    prevPrefixToLinearIdxMap[0] = 0;
    positionOffsetsPtr[0] = 0;
    TLLM_CHECK(numChoices <= maxMedusaTokens());
    for (SizeType32 ci = 0; ci < numChoices; ++ci)
    {
        auto const index = sortedIndices[ci];
        // Access choices based on sorting
        auto const& choice = choices[index];
        auto const curDepth = static_cast<SizeType32>(choice.size());
        // If new depth is found (choices are sorted by depth)
        if (curDepth != depth)
        {
            // Levels must be contiguous — every depth has at least one node.
            TLLM_CHECK(depth + 1 == curDepth);
            // Save TopK
            topKs[depth - 1] = maxTopK;
            // Accumulate TopK for global indexing in tree
            globalNodeInTreeIdx += maxTopK;

            // Swap hash maps: the "current" level becomes the parents of the next one
            prevPrefixToLinearIdxMap = curPrefixToLinearIdxMap;
            // Reset counters
            maxTopK = 0;
            curPrefixToLinearIdxMap.clear();
            // Increment depth
            depth++;
        }

        MedusaTreeNode node;
        node.depth = depth;
        node.linearIdx = ci + 1; // +1 because slot 0 is the implicit root
        // nodeId is index in array of sibling nodes
        node.nodeId = choice.back();
        // Compute prefix from current node and store its linear idx to hash map
        curPrefixToLinearIdxMap[prefixes[index]] = node.linearIdx;
        // Compute prefix of the parent node
        auto const parentPrefix = computePrefix(choice, choice.size() - 1);
        // Store parent's linear idx
        node.parentLinearIdx = prevPrefixToLinearIdxMap[parentPrefix];
        // Update TopK
        maxTopK = std::max(maxTopK, node.nodeId + 1);

        // Position offset is the depth of the node
        positionOffsetsPtr[node.linearIdx] = depth;
        // Save tree ids
        treeIdsPtr[node.linearIdx - 1] = globalNodeInTreeIdx + node.nodeId;
        // Save node
        tree[node.linearIdx] = node;
    }
    // Write TopK for the last level
    topKs[depth - 1] = maxTopK;

    // Second pass: wire up child links now that every node knows its parent.
    for (SizeType32 ci = 0; ci < numChoices + 1; ++ci)
    {
        auto& node = tree[ci];
        // For all nodes except root
        if (node.nodeId != -1)
        {
            // Add current linear index to parent's child
            tree[node.parentLinearIdx].childLinearIndices.push_back(ci);
        }
    }

    totalPaths = computePathsAndMask(tree, packedMask, paths);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! \brief Traverses the Medusa tree, filling the packed ancestor attention mask and the
//! root-to-leaf paths tensor.
//!
//! \param tree       fully linked tree (root at index 0, parent/child links populated).
//! \param packedMask [out] row r gets the union of r's own bit and its ancestors' bits.
//! \param paths      [out] one row per leaf; column d holds the linear idx of the path node at depth d.
//! \return number of paths (= number of leaves).
SizeType32 MedusaModule::computePathsAndMask(
    std::vector<MedusaTreeNode> const& tree, TensorPtr& packedMask, TensorPtr& paths) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto pathsPtr = bufferCast<SizeType32>(*paths);
    std::fill(pathsPtr, pathsPtr + paths->getSize(), -1);
    std::fill(bufferCast<SizeType32>(*packedMask), bufferCast<SizeType32>(*packedMask) + packedMask->getSize(), 0);

    SizeType32 numPaths = 0;
    // Count leaves
    for (auto const& node : tree)
    {
        // Leaf
        if (node.childLinearIndices.size() == 0)
        {
            numPaths++;
        }
    }

    // Stack contains indices of the node in the 'tree' vector
    std::stack<SizeType32> stack;
    // Add first node
    stack.push(0);
    // Index of the current tracked path
    SizeType32 pathIdx = 0;
    // Do DFS to construct paths and mask
    while (!stack.empty())
    {
        auto const ci = stack.top();
        stack.pop();
        auto const& node = tree[ci];
        // If not root copy mask of your parent (ancestors' bits propagate down the tree)
        if (node.nodeId != -1)
        {
            copyPackedMask(packedMask, node.parentLinearIdx, ci);
        }
        // Set value on the diagonal
        setOnePackedMask(packedMask, ci, ci);
        // Leaf
        if (node.childLinearIndices.size() == 0)
        {
            // Current nodeIdx
            SizeType32 nodeIdx = ci;
            // Until we hit the root; walk parent links bottom-up, writing each node at its depth column.
            while (tree[nodeIdx].nodeId != -1)
            {
                auto const& curNode = tree[nodeIdx];
                pathsPtr[(numPaths - 1 - pathIdx) * (medusaHeads() + 1) + curNode.depth] = curNode.linearIdx;
                // Go from top to the bottom
                nodeIdx = curNode.parentLinearIdx;
            }
            // Fill data for root
            pathsPtr[(numPaths - 1 - pathIdx) * (medusaHeads() + 1) + 0] = 0;
            // Fill next path
            pathIdx++;
        }
        // Go over all children of the node and visit them
        for (auto const& childLinearIdx : node.childLinearIndices)
        {
            stack.push(childLinearIdx);
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return numPaths;
}

//! \brief Copies one packed-mask row (numPackedMasks() words) from srcIdx to dstIdx.
void MedusaModule::copyPackedMask(TensorPtr& mask, SizeType32 srcIdx, SizeType32 dstIdx) const
{
    auto srcRow = ITensor::slice(mask, srcIdx, 1);
    auto dstRow = ITensor::slice(mask, dstIdx, 1);
    std::memcpy(
        bufferCast<SizeType32>(*dstRow), bufferCast<SizeType32>(*srcRow), numPackedMasks() * sizeof(SizeType32));
}

//! \brief Sets bit 'col' in packed-mask row 'row' (32 bits per mask word).
void MedusaModule::setOnePackedMask(TensorPtr& mask, SizeType32 row, SizeType32 col) const
{
    auto const maskIdx = static_cast<SizeType32>(col / 32);
    auto const bitIdx = col % 32;
    // Shift in unsigned space: 1 << 31 on a signed int shifts into the sign bit.
    auto const setMask = static_cast<SizeType32>(1u << bitIdx);
    bufferCast<SizeType32>(*mask)[row * numPackedMasks() + maskIdx] |= setMask;
}

//! \brief Packs the first 'len' entries of a choice into a single integer prefix,
//! PREFIX_CHUNK_SIZE_BITS bits per level, most significant chunk first.
//! Prefix has property that on the same depth, nodes that are more left in the tree have smaller prefix.
MedusaModule::Prefix MedusaModule::computePrefix(std::vector<SizeType32> const& vec, SizeType32 len) const
{
    SizeType32 constexpr BITS_PER_BYTE = 8;
    // Check that prefix fits into the underlying data type
    TLLM_CHECK_WITH_INFO(static_cast<SizeType32>(sizeof(Prefix)) * BITS_PER_BYTE / PREFIX_CHUNK_SIZE_BITS >= len,
        "Provided choices have depth (%d) larger than Prefix can fit (%ld).", len,
        sizeof(Prefix) * BITS_PER_BYTE / PREFIX_CHUNK_SIZE_BITS);
    Prefix prefix = 0;
    for (SizeType32 ci = 0; ci < len; ++ci)
    {
        auto val = vec[ci];
        TLLM_CHECK_WITH_INFO(val <= PREFIX_MAX_VALUE,
            "Provided choices have too large node degree (%d). Larger than Prefix can fit (%d).", val,
            PREFIX_MAX_VALUE);
        // Widen to Prefix before shifting: shifting a 32-bit value by >= 32 bits is UB.
        prefix |= (static_cast<Prefix>(vec[ci]) << PREFIX_CHUNK_SIZE_BITS * (len - 1 - ci));
    }
    return prefix;
}

//! \brief Logs the choices (in the given index order) at debug verbosity.
void MedusaModule::dumpChoices(MedusaChoices const& choices, std::vector<SizeType32> const& indices) const
{
    std::stringstream ss;
    ss << "Medusa choices = [";
    for (size_t ci = 0; ci < indices.size(); ++ci)
    {
        auto const idx = indices[ci];
        auto const& choice = choices[idx];
        ss << "[";
        for (size_t vi = 0; vi < choice.size(); ++vi)
        {
            ss << choice[vi];
            if (vi < choice.size() - 1)
            {
                ss << ", ";
            }
        }
        ss << "]";
        if (ci < indices.size() - 1)
        {
            ss << ", ";
        }
    }
    ss << "]" << std::endl;
    TLLM_LOG_DEBUG(ss.str().c_str());
}

} // namespace tensorrt_llm::runtime