/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "selectiveScanPlugin.h"
#include "tensorrt_llm/common/assert.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <vector>

using namespace nvinfer1;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::common;
using tensorrt_llm::plugins::SelectiveScanPluginCreator;
using tensorrt_llm::plugins::SelectiveScanPlugin;

static char const* SELECTIVE_SCAN_PLUGIN_VERSION{"1"};
static char const* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"};
PluginFieldCollection SelectiveScanPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> SelectiveScanPluginCreator::mPluginAttributes;

namespace
{

//! Number of scratch buffers the Mamba2 chunk scan carves out of the plugin workspace.
constexpr int kMamba2NumWorkspaceBuffers = 5;

//! Validates the plugin configuration against the current GPU; fails a TLLM_CHECK on an unsupported setup.
//! Shared by both constructors so that a deserialized engine is validated exactly like a freshly built one.
void checkPluginConfig(nvinfer1::DataType type, bool isMamba2)
{
    TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (!isMamba2), "Pre SM 80 GPUs do not support Mamba2");
    TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (type != DataType::kBF16),
        "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
    TLLM_CHECK_WITH_INFO((type == DataType::kBF16) || (type == DataType::kFLOAT) || (type == DataType::kHALF),
        "Only support float, half, and bfloat16.");
}

//! Size in bytes of one activation element: 4 for float, 2 for half and bfloat16 (the only types the
//! constructors accept). Must agree with sizeof(T) used by enqueueImpl<T>.
size_t activationElementSize(nvinfer1::DataType type)
{
    return type == DataType::kFLOAT ? sizeof(float) : sizeof(uint16_t);
}

//! Number of chunks the Mamba2 chunk scan needs scratch space for.
//! - remove_input_padding: input is [num_tokens, ...]; ceil(num_tokens / Q) plus one extra chunk per sequence
//!   (batch size taken from last_token_ids) — presumably because sequence boundaries can split chunks.
//! - padded: input is [batch_size, seq_len, ...]; batch_size * ceil(seq_len / Q).
size_t mamba2NumChunks(bool removePadding, nvinfer1::Dims const& inputDims, nvinfer1::Dims const& lastTokenIdsDims,
    size_t chunkSize)
{
    if (removePadding)
    {
        auto const batch = static_cast<size_t>(lastTokenIdsDims.d[0]);
        auto const numTokens = static_cast<size_t>(inputDims.d[0]);
        return (numTokens + chunkSize - 1) / chunkSize + batch;
    }
    auto const batch = static_cast<size_t>(inputDims.d[0]);
    auto const seqLen = static_cast<size_t>(inputDims.d[1]);
    return batch * ((seqLen + chunkSize - 1) / chunkSize);
}

//! Byte sizes of the Mamba2 chunk-scan scratch buffers, in the order they are carved out of the workspace.
//! All arithmetic is in size_t so large batches/models cannot overflow a 32-bit int.
//! @param elemSize size of one activation element (used by the Os and CB buffers, which are typed T*).
std::array<size_t, kMamba2NumWorkspaceBuffers> mamba2WorkspaceBufferSizes(size_t numChunks, size_t nHeads,
    size_t headDim, size_t nGroups, size_t dstate, size_t chunkSize, size_t elemSize)
{
    return {{
        numChunks * nHeads * dstate * headDim * elemSize,       // g_mxOs_
        numChunks * nHeads * dstate * headDim * sizeof(float),  // g_mxSt_ in float
        numChunks * nHeads * chunkSize * sizeof(float),         // g_mxdc_ in float
        numChunks * nHeads * chunkSize * sizeof(float),         // g_mxdA_ in float
        numChunks * nGroups * chunkSize * chunkSize * elemSize, // g_mxCB_
    }};
}

} // namespace

// Builder constructor: stores the configuration and validates it against the current GPU.
SelectiveScanPlugin::SelectiveScanPlugin(int dim, int dstate, int dtRank, int nHeads, int nGroups, int chunkSize,
    bool deltaSoftplus, nvinfer1::DataType type, bool removePadding, bool pagedState, bool zEnabled, bool isMamba2)
    : mDim(dim)
    , mDState(dstate)
    , mDtRank(dtRank)
    , mNHeads(nHeads)
    , mNGroups(nGroups)
    , mChunkSize(chunkSize)
    , mDeltaSoftplus(deltaSoftplus)
    , mType(type)
    , mRemovePadding(removePadding)
    , mPagedState(pagedState)
    , mZEnabled(zEnabled)
    , mIsMamba2(isMamba2)
{
    checkPluginConfig(mType, mIsMamba2);
}

// Deserialization constructor: reads the fields in exactly the order serialize() writes them, then applies the
// same validation as the builder constructor.
SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length)
{
    char const *d = reinterpret_cast<char const*>(data), *a = d;
    read(d, mDim);
    read(d, mDState);
    read(d, mDtRank);
    read(d, mNHeads);
    read(d, mNGroups);
    read(d, mChunkSize);
    read(d, mDeltaSoftplus);
    read(d, mType);
    read(d, mRemovePadding);
    read(d, mPagedState);
    read(d, mZEnabled);
    read(d, mIsMamba2);
    TLLM_CHECK(d == a + length);
    checkPluginConfig(mType, mIsMamba2);
}

// IPluginV2DynamicExt Methods

// Returns a new plugin with the same configuration and namespace.
nvinfer1::IPluginV2DynamicExt* SelectiveScanPlugin::clone() const noexcept
{
    auto* plugin = new SelectiveScanPlugin(mDim, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, mDeltaSoftplus,
        mType, mRemovePadding, mPagedState, mZEnabled, mIsMamba2);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}

// Outputs
// output_tensor: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
// state: [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
//        (only produced when the state is not paged)
nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    if (outputIndex == 0)
    {
        if (mIsMamba2)
        {
            // Same shape as the input tensor, but the last (feature) dimension is overridden with dim.
            auto ret = inputs[getInputTensorIdx()];
            ret.d[mRemovePadding ? 1 : 2] = exprBuilder.constant(mDim);
            return ret;
        }
        else
        {
            return inputs[getInputTensorIdx()];
        }
    }
    return inputs[getStateIdx()];
}

// Type/format constraints: int32 bookkeeping tensors, float A/delta_bias/D, int64 host pointer for paged state,
// and linear tensors of the plugin type for everything else.
bool SelectiveScanPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (pos == getHostRequestTypesIdx() || pos == getLastTokenIdsIdx()
        || (mRemovePadding && pos == getHostContextLengthIdx()) || (mPagedState && pos == getSlotMappingIdx()))
    {
        return inOut[pos].type == nvinfer1::DataType::kINT32;
    }
    else if (pos == getAIdx() || pos == getDeltaBiasIdx() || pos == getDIdx())
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    else if (mPagedState && pos == getStateIdx())
    {
        return inOut[pos].type == nvinfer1::DataType::kINT64;
    }
    else
    {
        return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
}

void SelectiveScanPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}

// Scratch space is only needed by the Mamba2 chunk scan; the sizes/order here must match the carving in
// enqueueImpl, which is why both use the same helpers.
size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    if (!mIsMamba2)
    {
        return 0;
    }
    size_t const numChunks = mamba2NumChunks(
        mRemovePadding, inputs[getInputTensorIdx()].dims, inputs[getLastTokenIdsIdx()].dims, mChunkSize);
    auto workspaces = mamba2WorkspaceBufferSizes(
        numChunks, mNHeads, mDim / mNHeads, mNGroups, mDState, mChunkSize, activationElementSize(mType));
    return calculateTotalWorkspaceSize(workspaces.data(), kMamba2NumWorkspaceBuffers);
}

// Fills `params` (zeroed first) with the scan configuration and the device pointers used by the kernels.
void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim,
    const size_t maxSeqLen, const size_t dstate, const size_t dtRank, const size_t nHeads, const size_t nGroups,
    const size_t chunkSize, void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A,
    void const* BC, void const* D, void const* z, void const* osPtr, void const* stPtr, void const* dcPtr,
    void const* dAPtr, void const* cbPtr, int const* lastTokenIds, int const* slotMapping, void* out,
    bool deltaSoftplus, bool removePadding)
{
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.max_seqlen = maxSeqLen;
    params.dstate = dstate;
    params.dt_rank = dtRank;
    params.nheads = nHeads;
    params.ngroups = nGroups;
    params.chunk_size = chunkSize;
    params.delta_softplus = deltaSoftplus;
    params.remove_padding = removePadding;
    params.is_mamab2 = mIsMamba2; // field name as declared in SSMParamsBase

    // Set the pointers and strides.
    params.u_ptr = const_cast<void*>(x);
    params.delta_ptr = const_cast<void*>(delta);
    params.A_ptr = const_cast<void*>(A);
    params.BC_ptr = const_cast<void*>(BC);
    params.D_ptr = const_cast<void*>(D);
    params.delta_bias_ptr = const_cast<void*>(deltaBias);
    params.out_ptr = out;
    params.x_ptr = statePtr;
    params.z_ptr = const_cast<void*>(z);
    params.Os_ptr = const_cast<void*>(osPtr);
    params.St_ptr = const_cast<void*>(stPtr);
    params.dc_ptr = const_cast<void*>(dcPtr);
    params.dA_ptr = const_cast<void*>(dAPtr);
    params.CB_ptr = const_cast<void*>(cbPtr);
    params.last_token_ids_ptr = lastTokenIds;
    params.slot_mapping_ptr = slotMapping;
}

template <typename T>
int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream)
{
    // inputs
    //     0. input_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
    //     1. state mamba: [batch_size, dstate, dim] or host [1] containing only pointer for paged_state
    //              mamba2: [batch_size, nheads, dstate, dim] or host [1] containing only pointer for paged_state
    //     2. delta, mamba: [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
    //               mamba2: [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding
    //     3. delta_bias, [dim] for mamba, [nheads] for mamba2
    //     4. A, [dstate, dim] for mamba, [nheads] for mamba2
    //     5. BC, mamba: [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
    //            mamba2: [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for
    //            remove_input_padding
    //     6. D, [dim] for mamba, [nheads] for mamba2
    //     7. host_request_types [batch_size] int32. 0: context; 1: generation.
    //     8. last_token_ids [batch_size] int32
    //     9. host_context_lengths [batch_size] int32, optional for remove_input_padding
    //    10. state_slot_mapping [batch_size] int32, optional for paged state
    //    11. z [batch_size, max_seq_len, dim] or [num_tokens, dim]
    // outputs
    //     0. output_tensor [batch_size, max_seq_len, dim] or [num_tokens, dim]
    //     1. state, [batch_size, dstate, dim] for mamba, [batch_size, nheads, dstate, dim] for mamba2
    auto const batch_size = inputDesc[getHostRequestTypesIdx()].dims.d[0];
    if (batch_size <= 0)
    {
        // Nothing to scan; also avoids dereferencing empty host_context_lengths / host_request_types below.
        return 0;
    }

    int max_seq_len;
    if (mRemovePadding)
    {
        int const* host_context_length = static_cast<int const*>(inputs[getHostContextLengthIdx()]);
        max_seq_len = *std::max_element(host_context_length, host_context_length + batch_size);
    }
    else
    {
        max_seq_len = inputDesc[getInputTensorIdx()].dims.d[1];
    }

    // only support context or generation, not for both of them; request 0 decides the path
    RequestType const* reqTypes = static_cast<RequestType const*>(inputs[getHostRequestTypesIdx()]);

    int const* slotMapping = mPagedState ? static_cast<int const*>(inputs[getSlotMappingIdx()]) : nullptr;
    void const* z = mZEnabled ? inputs[getZIdx()] : nullptr;
    // With paged state, the state input is a host tensor holding a single pointer; otherwise output 1 is used.
    void* statePtr = mPagedState ? *static_cast<void* const*>(inputs[getStateIdx()]) : outputs[1];

    // Mamba2 context requests need chunk-scan scratch buffers; carve them with the same sizes and order that
    // getWorkspaceSize() reserved (sizeof(T) == activationElementSize(mType) since enqueue dispatches on mType).
    T* mxOs = nullptr;
    float* mxSt = nullptr;
    float* mxdc = nullptr;
    float* mxdA = nullptr;
    T* mxCB = nullptr;
    if (mIsMamba2 && reqTypes[0] != RequestType::kGENERATION)
    {
        size_t const numChunks = mamba2NumChunks(
            mRemovePadding, inputDesc[getInputTensorIdx()].dims, inputDesc[getLastTokenIdsIdx()].dims, mChunkSize);
        auto const sizes = mamba2WorkspaceBufferSizes(
            numChunks, mNHeads, mDim / mNHeads, mNGroups, mDState, mChunkSize, sizeof(T));
        int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
        size_t offset = 0;
        mxOs = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, sizes[0]));
        mxSt = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sizes[1]));
        mxdc = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sizes[2]));
        mxdA = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, sizes[3]));
        mxCB = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, sizes[4]));
    }

    SSMParamsBase ssm_params;
    setSSMParams(ssm_params, batch_size, mDim, max_seq_len, mDState, mDtRank, mNHeads, mNGroups, mChunkSize, statePtr,
        inputs[getInputTensorIdx()], inputs[getDeltaIdx()], inputs[getDeltaBiasIdx()], inputs[getAIdx()],
        inputs[getBCIdx()], inputs[getDIdx()], z, mxOs, mxSt, mxdc, mxdA, mxCB,
        static_cast<int const*>(inputs[getLastTokenIdsIdx()]), slotMapping, outputs[0], mDeltaSoftplus,
        mRemovePadding);

    if (reqTypes[0] == RequestType::kCONTEXT)
    {
        if (mIsMamba2)
        {
            invokeChunkScan<T, float>(ssm_params, stream);
        }
        else
        {
            invokeSelectiveScan<T, float>(ssm_params, stream);
        }
    }
    else if (reqTypes[0] == RequestType::kGENERATION)
    {
        invokeSelectiveScanUpdate<T, float>(ssm_params, stream);
    }
    return 0;
}

// Dispatches enqueueImpl on the configured activation type; does nothing while the engine is being built.
int SelectiveScanPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    if (isBuilding())
    {
        return 0;
    }
    if (mType == DataType::kHALF)
    {
        return enqueueImpl<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
    }
    else if (mType == DataType::kFLOAT)
    {
        return enqueueImpl<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
    }
#ifdef ENABLE_BF16
    else if (mType == DataType::kBF16)
    {
        return enqueueImpl<__nv_bfloat16>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
    }
#endif
    return 0;
}

// IPluginV2Ext Methods

// Output 0 has the input tensor's type; output 1 (state) has the state input's type.
nvinfer1::DataType SelectiveScanPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    if (index == 0)
    {
        return inputTypes[getInputTensorIdx()];
    }
    else
    {
        return inputTypes[getStateIdx()];
    }
}

// IPluginV2 Methods

char const* SelectiveScanPlugin::getPluginType() const noexcept
{
    return SELECTIVE_SCAN_PLUGIN_NAME;
}

char const* SelectiveScanPlugin::getPluginVersion() const noexcept
{
    return SELECTIVE_SCAN_PLUGIN_VERSION;
}

// With paged state the kernels update the state through the host pointer, so no state output is produced.
int SelectiveScanPlugin::getNbOutputs() const noexcept
{
    return mPagedState ? 1 : 2;
}

int SelectiveScanPlugin::initialize() noexcept
{
    return 0;
}

void SelectiveScanPlugin::terminate() noexcept {}

// Must list exactly the fields written by serialize().
size_t SelectiveScanPlugin::getSerializationSize() const noexcept
{
    return sizeof(mDim) + sizeof(mDState) + sizeof(mDtRank) + sizeof(mNHeads) + sizeof(mNGroups)
        + sizeof(mChunkSize) + sizeof(mDeltaSoftplus) + sizeof(mType) + sizeof(mRemovePadding) + sizeof(mPagedState)
        + sizeof(mZEnabled) + sizeof(mIsMamba2);
}

// Writes the configuration in the order the deserialization constructor reads it.
void SelectiveScanPlugin::serialize(void* buffer) const noexcept
{
    char *d = static_cast<char*>(buffer), *a = d;
    write(d, mDim);
    write(d, mDState);
    write(d, mDtRank);
    write(d, mNHeads);
    write(d, mNGroups);
    write(d, mChunkSize);
    write(d, mDeltaSoftplus);
    write(d, mType);
    write(d, mRemovePadding);
    write(d, mPagedState);
    write(d, mZEnabled);
    write(d, mIsMamba2);
    assert(d == a + getSerializationSize());
}

void SelectiveScanPlugin::destroy() noexcept
{
    delete this;
}

///////////////

// Registers the plugin-field metadata consumed by createPlugin().
SelectiveScanPluginCreator::SelectiveScanPluginCreator()
{
    // Fill PluginFieldCollection with PluginField arguments metadata
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("dim", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("dstate", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("dt_rank", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("nheads", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("ngroups", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("chunk_size", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("delta_softplus", nullptr, PluginFieldType::kINT8, 1));
    mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("remove_input_padding", nullptr, PluginFieldType::kINT8, 1));
    mPluginAttributes.emplace_back(PluginField("paged_state", nullptr, PluginFieldType::kINT8, 1));
    mPluginAttributes.emplace_back(PluginField("z_enabled", nullptr, PluginFieldType::kINT8, 1));
    mPluginAttributes.emplace_back(PluginField("is_mamba2", nullptr, PluginFieldType::kINT8, 1));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

char const* SelectiveScanPluginCreator::getPluginName() const noexcept
{
    return SELECTIVE_SCAN_PLUGIN_NAME;
}

char const* SelectiveScanPluginCreator::getPluginVersion() const noexcept
{
    return SELECTIVE_SCAN_PLUGIN_VERSION;
}

PluginFieldCollection const* SelectiveScanPluginCreator::getFieldNames() noexcept
{
    return &mFC;
}

// Builds a plugin from the PluginFieldCollection. Any failure (wrong field type, invalid configuration) is
// reported through caughtError() and yields nullptr; nothing escapes this noexcept function.
IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    try
    {
        // Initialized so that an absent field yields a deterministic value instead of an indeterminate read.
        int dim{0};
        int dstate{0};
        int dtRank{0};
        int nHeads{0};
        int nGroups{0};
        int chunkSize{0};
        bool deltaSoftplus{false};
        bool removePadding{false};
        bool pagedState{false};
        bool zEnabled{false};
        bool isMamba2{false};
        nvinfer1::DataType type{nvinfer1::DataType::kFLOAT};

        // Readers for the two encodings registered in the creator's constructor (kINT32 and kINT8 flags).
        auto const readInt32 = [](PluginField const& field)
        {
            TLLM_CHECK(field.type == PluginFieldType::kINT32);
            return static_cast<int>(*static_cast<int32_t const*>(field.data));
        };
        auto const readFlag = [](PluginField const& field)
        {
            TLLM_CHECK(field.type == PluginFieldType::kINT8);
            return *static_cast<int8_t const*>(field.data) != 0;
        };

        // Read configurations from each fields
        PluginField const* fields = fc->fields;
        for (int i = 0; i < fc->nbFields; ++i)
        {
            PluginField const& field = fields[i];
            char const* attrName = field.name;
            if (!strcmp(attrName, "dim"))
            {
                dim = readInt32(field);
            }
            else if (!strcmp(attrName, "dstate"))
            {
                dstate = readInt32(field);
            }
            else if (!strcmp(attrName, "dt_rank"))
            {
                dtRank = readInt32(field);
            }
            else if (!strcmp(attrName, "nheads"))
            {
                nHeads = readInt32(field);
            }
            else if (!strcmp(attrName, "ngroups"))
            {
                nGroups = readInt32(field);
            }
            else if (!strcmp(attrName, "chunk_size"))
            {
                chunkSize = readInt32(field);
            }
            else if (!strcmp(attrName, "delta_softplus"))
            {
                deltaSoftplus = readFlag(field);
            }
            else if (!strcmp(attrName, "type_id"))
            {
                type = static_cast<nvinfer1::DataType>(readInt32(field));
            }
            else if (!strcmp(attrName, "remove_input_padding"))
            {
                removePadding = readFlag(field);
            }
            else if (!strcmp(attrName, "paged_state"))
            {
                pagedState = readFlag(field);
            }
            else if (!strcmp(attrName, "z_enabled"))
            {
                zEnabled = readFlag(field);
            }
            else if (!strcmp(attrName, "is_mamba2"))
            {
                isMamba2 = readFlag(field);
            }
        }

        auto* obj = new SelectiveScanPlugin(dim, dstate, dtRank, nHeads, nGroups, chunkSize, deltaSoftplus, type,
            removePadding, pagedState, zEnabled, isMamba2);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Rebuilds a plugin from serialized engine data; errors are reported through caughtError() and yield nullptr.
IPluginV2* SelectiveScanPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept
{
    // This object will be deleted when the network is destroyed, which will
    // call SelectiveScanPlugin::destroy()
    try
    {
        auto* obj = new SelectiveScanPlugin(serialData, serialLength);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}