diff --git a/.gitattributes b/.gitattributes index e72ba0fe7b..7486041ffd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -7,3 +7,5 @@ triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text *cubin.cpp filter=lfs diff=lfs merge=lfs -text docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text +tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text +tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index f6625a0559..745713e581 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-1.1.0rc1-green)](./tensorrt_llm/version.py) +[![version](https://img.shields.io/badge/release-1.1.0rc2-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap) @@ -18,10 +18,9 @@ TensorRT-LLM
## Tech Blogs -* [08/06] Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM +* [08/05] Running a High-Performance GPT-OSS-120B Inference Server with TensorRT-LLM ✨ [➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md) - * [08/01] Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization) ✨ [➡️ link](./docs/source/blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.md) @@ -44,6 +43,7 @@ TensorRT-LLM ✨ [➡️ link](./docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md) ## Latest News +* [08/05] 🌟 TensorRT-LLM delivers Day-0 support for OpenAI's latest open-weights models: GPT-OSS-120B [➡️ link](https://huggingface.co/openai/gpt-oss-120b) and GPT-OSS-20B [➡️ link](https://huggingface.co/openai/gpt-oss-20b) * [07/15] 🌟 TensorRT-LLM delivers Day-0 support for LG AI Research's latest model, EXAONE 4.0 [➡️ link](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B) * [06/17] Join NVIDIA and DeepInfra for a developer meetup on June 26 ✨ [➡️ link](https://events.nvidia.com/scaletheunscalablenextgenai) * [05/22] Blackwell Breaks the 1,000 TPS/User Barrier With Meta’s Llama 4 Maverick @@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer ## Useful Links - [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM. - [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM. -- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models. +- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models. - [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news. diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4f9d5f0f22..198fd22a71 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -69,7 +69,7 @@ add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE") add_compile_definitions("TLLM_ENABLE_CUDA") set(BINDING_TYPE - "pybind" + "nanobind" CACHE STRING "Binding type of Python bindings for C++ runtime and batch manager") diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h index 394f7fb7bf..0978905b5e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h +++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h @@ -24,7 +24,6 @@ #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/modelConfig.h" -#include "tensorrt_llm/runtime/request.h" #include "tensorrt_llm/runtime/worldConfig.h" namespace tensorrt_llm::runtime @@ -88,37 +87,6 @@ public: SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const; private: - //! @brief Setups decoder internal tensors for new speculative decoding request - static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream, - CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode, - SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new request in Draft model Sps mode - static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream); - - //! @brief Setups decoder internal tensors for new Medusa request - static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new Lookahead request - static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! @brief Setups decoder internal tensors for new Explicit draft tokens request - static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! @brief Setups decoder internal tensors for new Eagle request - static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - [[nodiscard]] std::shared_ptr retrieveDraftLogits(runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig, std::shared_ptr const& tensor, - runtime::BufferManager const& bufferManager) const; - bool mSpeculativeDecodingFastLogits; bool mIsLeaderInOrchMode; bool mIsNormalizeLogProbs; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index e4d13c9e17..f069e3ac7f 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1110,7 +1110,7 @@ public: [[nodiscard]] SizeType32 getNumDraftTokens() const { - return mDraftTokens->size(); + return hasDraftTokens() ? mDraftTokens->size() : 0; } void discardDraftTokens(SizeType32 numTokensToDiscard) diff --git a/cpp/include/tensorrt_llm/common/logger.h b/cpp/include/tensorrt_llm/common/logger.h index df84e22638..c8164b10e5 100644 --- a/cpp/include/tensorrt_llm/common/logger.h +++ b/cpp/include/tensorrt_llm/common/logger.h @@ -54,20 +54,21 @@ public: #if defined(_MSC_VER) template - void log(Level level, char const* format, Args const&... args); + void log(Level const level, char const* format, Args const&... args); template - void log(Level level, int rank, char const* format, Args const&... args); + void log(Level const level, int const rank, char const* format, Args const&... args); #else template - void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); + void log(Level const level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); template - void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0))); + void log(Level const level, int const rank, char const* format, Args const&... args) + __attribute__((format(printf, 4, 0))); #endif template - void log(Level level, std::string const& format, Args const&... args) + void log(Level const level, std::string const& format, Args const&... args) { return log(level, format.c_str(), args...); } @@ -134,7 +135,7 @@ private: }; template -void Logger::log(Logger::Level level, char const* format, Args const&... args) +void Logger::log(Logger::Level const level, char const* format, Args const&... args) { if (isEnabled(level)) { diff --git a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h index 2a2f1f8369..98b26a276c 100644 --- a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h +++ b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h @@ -52,29 +52,30 @@ public: AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2) : mModelConfig(std::move(modelConfig)) , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(), - worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()} + worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), + worldConfig.getTensorParallelism()} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { } CacheState(std::vector nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, - SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType, - AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, - int DPrank = 0, int DPsize = 0) + SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, + nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, + bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0) : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock} - , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize} + , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { } CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, - SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType, - AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, - int DPrank = 0, int DPsize = 0) + SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, + nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, + bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0) : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock} - , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize} + , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { @@ -83,7 +84,7 @@ public: [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept { return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig - && mDataType == other.mDataType; + && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType; } struct ModelConfig @@ -103,6 +104,7 @@ public: { SizeType32 mTensorParallelism; SizeType32 mPipelineParallelism; + SizeType32 mContextParallelism; bool mEnableAttentionDP; SizeType32 mDPrank; SizeType32 mDPsize; @@ -110,8 +112,8 @@ public: [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept { return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism - && mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank - && mDPsize == other.mDPsize; + && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP + && mDPrank == other.mDPrank && mDPsize == other.mDPsize; } }; @@ -125,6 +127,11 @@ public: { } + [[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept + { + return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor; + } + // attentionType ; AttentionType mAttentionType; int mKvFactor; @@ -162,6 +169,7 @@ public: sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n"; sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n"; sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n"; + sstring << "cp:" << mParallelConfig.mContextParallelism << "\n"; sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n"; sstring << "datatype:" << static_cast(mDataType) << "\n"; sstring << "attentionType:" << static_cast(mAttentionConfig.mAttentionType) << "\n"; diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h index deeb0fa0af..4344f423ac 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingInput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h @@ -102,11 +102,13 @@ public: { public: TensorPtr draftLogits; + TensorPtr draftLogitsHost; TensorPtr draftProbs; TensorPtr targetProbs; TensorPtr numDraftTokens; TensorPtr numDraftTokensHost; TensorPtr draftTokenIds; + TensorPtr draftTokenIdsHost; TensorPtr useDraftLogits; TensorPtr useDraftLogitsHost; diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h deleted file mode 100644 index e8f851b7d7..0000000000 --- a/cpp/include/tensorrt_llm/runtime/request.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/executor/executor.h" -#include "tensorrt_llm/runtime/iTensor.h" - -#include - -namespace tensorrt_llm::runtime::decoder_batch -{ - -class Request -{ -public: - using TensorConstPtr = ITensor::SharedConstPtr; - using TensorPtr = ITensor::SharedPtr; - using BufferPtr = IBuffer::SharedPtr; - - explicit Request(SizeType32 inputLen) - : inputLen(inputLen) - { - } - - //! Mandatory parameters - SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps - - // optional parameters - SizeType32 generatedTokensPerEngineStep{1}; // - - //! Optional parameters for speculative decoding - BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu - std::optional draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu - TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu - TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu - std::optional lookaheadRuntimeConfig; - std::optional eagleConfig; -}; - -} // namespace tensorrt_llm::runtime::decoder_batch diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index 9a438df9a2..da44fba60c 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -1012,7 +1012,7 @@ CUBIN_EXPORT __global__ if (threadIdx.x < smem.gemm1AccColMax.size) { auto const idx = threadIdx.x; - smem.gemm1AccColMax[idx] = mha::numeric_limits::lowest(); + smem.gemm1AccColMax[idx] = safeInitRowMax; smem.gemm1AccColSum[idx] = 0; } smem.gemm1WarpGrpBar.arrive_and_wait(); @@ -1949,7 +1949,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, uint32_t const globalRow = tileStartRow + row; if (globalRow >= cacheSeqLen) { - acc(m, n)(i, j) = mha::numeric_limits::lowest(); + acc(m, n)(i, j) = safeInitRowMax; continue; } if (globalRow >= maskStartRow) @@ -1957,7 +1957,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, uint32_t const maskRow = globalRow - maskStartRow; if ((bit_mask >> maskRow) == 0) { - acc(m, n)(i, j) = mha::numeric_limits::lowest(); + acc(m, n)(i, j) = safeInitRowMax; } } } @@ -2087,7 +2087,7 @@ __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32 #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { - acc(m, n)(i, j) = mha::numeric_limits::lowest(); + acc(m, n)(i, j) = safeInitRowMax; } } } @@ -2380,9 +2380,9 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, { uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; assert((col < nbValidCols) == bool(endMask & (1ULL << col))); - if (((mask >> col) & 1) == 0) + if ((mask & (1ULL << col)) == 0) { - acc(m, n)(i, j) = mha::numeric_limits::lowest(); + acc(m, n)(i, j) = safeInitRowMax; } } } @@ -2410,7 +2410,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uin #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { - acc(m, n)(i, j) = mha::numeric_limits::lowest(); + acc(m, n)(i, j) = safeInitRowMax; } } } diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 2559ae5484..36cbe76544 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -833,7 +833,7 @@ public: // Runs for 3 iterations or 1 second and picks the best option int pickBestTactic(MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile) { - auto tactics = mMoERunner.getTactics(); + auto tactics = mMoERunner.getTactics(static_cast(gemm_to_profile)); ::nvtx3::scoped_range nvtx(tensorrt_llm::common::nvtx::nextColor(), "Tactic Profiling GEMM " + std::to_string(static_cast(gemm_to_profile))); // We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we @@ -925,12 +925,14 @@ public: std::pair setTactic( int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile) { - auto tactics = mMoERunner.getTactics(); + auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1); + auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2); std::vector, GemmToProfile>> tactics_to_profile{ {tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}}; for (auto& combo : tactics_to_profile) { auto& t = combo.first.get(); + auto& tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2; if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER) { t = 0; // Unneeded tactic, set to 0 @@ -947,7 +949,7 @@ public: } } - mMoERunner.setTactic(tactics[tactic_idx1], tactics[tactic_idx2]); + mMoERunner.setTactic(tactics1[tactic_idx1], tactics2[tactic_idx2]); mBestTacticGemm1 = tactic_idx1; mBestTacticGemm2 = tactic_idx2; return {tactic_idx1, tactic_idx2}; @@ -965,7 +967,7 @@ public: auto expert_weights_size = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size; - auto tactics = mMoERunner.getTactics()[tactic_idx]; + auto tactics = mMoERunner.getTactics(static_cast(gemm_to_profile))[tactic_idx]; if (static_cast(gemm_to_profile) != static_cast(mGemmProfilerBackend.mGemmToProfile)) { throw std::runtime_error("Configuration mismatch between mGemmProfilerBackend and runMoEPermute"); @@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark::runBenchmark(benchmark::State& state } if (LOG_LEVEL >= INFO) { - auto tactics = mMoERunner.getTactics(); - std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics.size() << "\n" - << tactics[tactic_idx1].toString() << std::endl; - std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics.size() << "\n" - << tactics[tactic_idx2].toString() << std::endl; + auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1); + auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2); + std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n" + << tactics1[tactic_idx1].toString() << std::endl; + std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n" + << tactics2[tactic_idx2].toString() << std::endl; } state.counters["tactic_idx1"] = tactic_idx1; state.counters["tactic_idx2"] = tactic_idx2; diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu index b784c6d0bc..8e18694ad7 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu @@ -42,148 +42,15 @@ struct WeightParams ->Apply(argGen>>) template -auto listAllTactics() +auto listAllTactics(MoeGemmId gemm_id) { int const sm = getSMVersion(); using RunnerType = decltype(BenchClass::mMoERunner); - return RunnerType::getTactics(sm); + return RunnerType::getTactics(sm, gemm_id); } template -int parseTacticToId(nlohmann::json tactic_config) -{ - bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get(); - int tile_shape_id = -1; - std::array tile_shape; - if (tactic_config.at("tile_shape").is_array()) - tactic_config.at("tile_shape").get_to(tile_shape); - else - tile_shape_id = tactic_config.at("tile_shape").get(); - - std::vector confs = listAllTactics(); - - try - { - for (int i = 0; i < confs.size(); i++) - { - auto const& c = confs[i]; - if (c.is_tma_warp_specialized != is_tma_warp_specialized) - continue; - - if (!is_tma_warp_specialized) - { - int stages = tactic_config.at("stages").get(); - if (c.stages != stages) - continue; - } - - if (tile_shape_id != -1) - { - int comp = c.getTileConfigAsInt(); - if (tile_shape_id != comp) - continue; - if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get()) - continue; - - // Found matching config - return i; - } - - // Handle if the user provided a shape instead of the enum value - if (is_tma_warp_specialized) - { - // TODO Add cases for blackwell shapes - using Kv = uint64_t; - constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); }; - static std::unordered_map const tile_map{ - {K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B}, - {K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B}, - {K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B}, - {K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B}, - {K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B}, - - {K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B}, - {K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B}, - {K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B}, - {K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B}, - {K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B}, - {K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B}, - }; - - if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1]))) - continue; - - static std::unordered_map const cluster_map{ - // CTA configs for M=64 - {K(1, 1), ClusterShape::ClusterShape_1x1x1}, - {K(2, 1), ClusterShape::ClusterShape_2x1x1}, - {K(1, 2), ClusterShape::ClusterShape_1x2x1}, - {K(2, 2), ClusterShape::ClusterShape_2x2x1}, - }; - - std::array cluster_shape; - tactic_config.at("cluster_shape").get_to(cluster_shape); - - if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1]))) - continue; - - // Found matching config - return i; - } - else - { - std::array warp_shape; - tactic_config.at("warp_shape").get_to(warp_shape); - - using Kv = uint64_t; - constexpr static auto K = [](std::array a, std::array b) - { - uint64_t sum = 0; - for (auto v : a) - sum = sum * 512 + v; - for (auto v : b) - sum = sum * 256 + v; - return sum; - }; - static std::unordered_map tile_map{ - {K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}, - - {K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64}, - {K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64}, - - {K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}, - {K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64}, - {K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}, - - {K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64}, - {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}, - {K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64}, - {K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}, - {K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64}, - - {K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}, - - {K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64} - - }; - if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape))) - continue; - - // Found matching config - return i; - } - } - } - catch (std::out_of_range const& e) - { - std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl; - } - - return -1; -} - -template -void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) +void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids, MoeGemmId gemm_id) { if (tactic.is_number_integer()) { @@ -193,20 +60,16 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector& tactic_ids) { for (auto c : tactic) { - parseTacticToVectorID(c, tactic_ids); + parseTacticToVectorID(c, tactic_ids, gemm_id); } } - else if (tactic.is_object()) - { - tactic_ids.push_back(parseTacticToId(tactic)); - } else if (tactic.is_string()) { assert(tactic.is_string()); auto tactic_name = tactic.get(); if (tactic_name == "all") { - auto all_tactics = listAllTactics(); + auto all_tactics = listAllTactics(gemm_id); tactic_ids.resize(all_tactics.size()); std::iota(tactic_ids.begin(), tactic_ids.end(), 0); } @@ -410,39 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) } // Do this after filtering datatypes as tactics only make sense if we know the data type - bool has_tactic_ids2 = false; std::vector tactic_ids1{}; std::vector tactic_ids2{}; - if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2")) + if (run_config.contains("tactic_id1")) { - if (run_config.contains("tactic_id")) - { - throw std::invalid_argument("Cannot use tactic_id and tactic_idX"); - } - has_tactic_ids2 = true; - parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1); - parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2); + parseTacticToVectorID(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1); } - else + if (run_config.contains("tactic_id2")) { - parseTacticToVectorID(run_config["tactic_id"], tactic_ids1); - has_tactic_ids2 = false; - tactic_ids2.resize(1); // Dummy value so we loop exactly once below - } - if (tactic_ids1.empty() || tactic_ids2.empty()) - { - std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl; - static bool printed = false; - if (!printed) - { - printed = true; - std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n"; - auto confs = listAllTactics(); - for (auto c : confs) - std::cerr << c.toString(); - } - - continue; + parseTacticToVectorID(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2); } auto get_or = [&](auto name, auto def) @@ -478,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) } else if (gemm_to_profile == (int) GemmToProfile::GEMM_2) { - if (!has_tactic_ids2) - tactic_ids2 = std::move(tactic_ids1); tactic_ids1 = {-1}; } } @@ -494,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark) return val; }; + if (tactic_ids1.empty() || tactic_ids2.empty()) + { + std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl; + static bool printed = false; + if (!printed) + { + printed = true; + std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n"; + for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) + { + std::cerr << "GEMM " << (int) gemm_id << ":\n"; + auto confs = listAllTactics(gemm_id); + for (auto c : confs) + std::cerr << c.toString(); + std::cerr << std::endl; + } + } + + continue; + } + for (auto t1 : tactic_ids1) { - // tactic_ids2 will have one dummy value if has_tactic_ids2 = false for (auto t2 : tactic_ids2) { - if (!has_tactic_ids2) - t2 = t1; - benchmark->Args({num_experts, // get_range("k"), // get_range("hidden_size"), // @@ -531,7 +385,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark) // {ActivationType::Relu, ActivationType::Gelu, // ActivationType::Silu, ActivationType::Geglu, // ActivationType::Swiglu}; - auto cutlass_tactic = {-1}; // {0,..., listAllTactics().size()}; + auto cutlass_tactic = {-1}; // {0,..., listAllTactics(MoeGemmId).size()}; auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2}; for (auto num_expert : num_experts) @@ -558,14 +412,18 @@ void argGen(benchmark::internal::Benchmark* benchmark) { if (LOG_LEVEL >= VERBOSE) { - std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n"; - int i = 0; - for (auto& t : listAllTactics()) + std::cout << "== List of all tactics for dtype " << (int) BenchClass::toDTypeID() << " ==\n"; + for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2}) { - std::cout << "Tactic " << i << ":\n"; - std::cout << t.toString() << std::endl; + int i = 0; + std::cout << "=== GEMM " << (int) gemm_id << " ===\n"; + for (auto& t : listAllTactics(gemm_id)) + { + std::cout << "==== Tactic " << i << " ====\n"; + std::cout << t.toString() << std::endl; - i++; + i++; + } } } @@ -652,7 +510,6 @@ void help() " \"bias\": int, (optional)\n" " \"do_final_scale\": int, (optional)\n" " \"act_fn\": int,\n" - " \"tactic_id\": tactic, (see below)\n" " \"tactic_id1\": tactic, (see below)\n" " \"tactic_id2\": tactic, (see below)\n" " \"dtypes\": [string, ...], (optional)\n" @@ -676,27 +533,14 @@ void help() "- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n" "- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = " "swiglu\n" - "- \"tactic_id, tactic_id1, tactic_id2\"\n" - "The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in " - "auto mode)\n" - "Use tactic_idX to set the tactic for the corresponding GEMM" + "- \"tactic_id1, tactic_id2\"\n" + "The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM" "Valid tactics are:\n" - " - An object:\n" - " {\n" - " \"is_tma_warp_specialized\": bool,\n" - " \"tile_shape\": [int, int, int] or int,\n" - " \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape " - "is " - "an int)\n" - " \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n" - " \"stages\": int, (required for non-sm90)\n" - " },\n" - " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test " - "configurations\n" - " - An array: of integers or objects, forms a list of tactics to sweep\n" + " - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types " + "or GPU architectures\n" + " - An array: of integers, forms a list of tactics to sweep\n" " - The string \"all\": This will sweep through all possible tactics\n" - " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark " - "case. " + " - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. " "Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate " "results" "- dtypes - A list of dtypes to run this config through.\n" diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index c4814c1d4e..2e625f4687 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -294,8 +294,7 @@ if(TARGET ${NIXL_WRAPPER_TARGET}) endif() if(NOT WIN32) - set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS - "-Wl,-rpath='$ORIGIN'") + set_target_properties(${SHARED_TARGET} PROPERTIES BUILD_RPATH "$ORIGIN") endif() if(BUILD_PYT) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 503c2e6c5d..e73e0f1541 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session) TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA"); return false; } + if (selfConfig.getParallelConfig().mContextParallelism != 1 + || destConfig.getParallelConfig().mContextParallelism != 1) + { + TLLM_LOG_WARNING( + "CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).", + selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism); + return false; + } std::unordered_set setVecDest{ destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()}; diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 16771709bb..3335d69a01 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -20,11 +20,14 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/medusaBuffers.h" #include "tensorrt_llm/batch_manager/utils/logitsThread.h" +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/decoderState.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager using SizeType32 = CreateNewDecoderRequests::SizeType32; using TensorPtr = CreateNewDecoderRequests::TensorPtr; using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr; +template +using OptionalRef = tensorrt_llm::common::OptionalRef; namespace { @@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -} // namespace - -void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream, CudaStream const& decoderStream, - SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens) +void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr const& reqDraftLogits, + ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - if (speculativeDecodingMode.predictsDraftTokens()) + if (!speculativeDecodingFastLogits) { - auto const& stream = decoderStream; - BufferManager manager{std::make_shared(stream.get())}; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + bufferManager.copy(*reqDraftLogits, *draftLogitsHost); + return; + } - auto& dJointOutput = jointDecodingOutput; + if (isLeaderInOrchMode) + { + // reqDraftLogits contains metadata for fast-logits path; validate size. + auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo); + TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize, + "Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo."); + te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{}; + std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize); + utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype()); - TensorPtr nextDraftTokens - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1); - // FIXME: can we skip this? - manager.setZero(*nextDraftTokens); - if (speculativeDecodingMode.variableDraftLength()) + // Broadcast to other ranks if needed + if (worldConfig.isTensorParallel()) { - TensorPtr nextDraftTokensLen - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1); - manager.setZero(*nextDraftTokensLen); + auto const& commSession = COMM_SESSION; + auto shape = draftLogitsHost->getShape(); + commSession.bcastValue(shape.d[0], 0); + commSession.bcastValue(shape.d[1], 0); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } } + else + { + TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(), + "Fast logits path requires tensor-parallel broadcast for non-leader ranks."); - if (speculativeDecodingMode.isDraftTokensExternal()) - { - newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream); - } - else if (speculativeDecodingMode.isMedusa()) - { - newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens); - } - else if (speculativeDecodingMode.isLookaheadDecoding()) - { - newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isExplicitDraftTokens()) - { - newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isEagle()) - { - newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream); + // Get logits from leader rank + auto const& commSession = COMM_SESSION; + int64_t dims[2]; + commSession.bcastValue(dims[0], 0); + commSession.bcastValue(dims[1], 0); + draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]})); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} +}; -void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream) +//! @brief Setups decoder internal tensors for new request in Draft model Sps mode +void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq, + SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig, + bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - BufferManager manager{std::make_shared(decoderStream.get())}; + BufferManager decoderBufferManager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs); + auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs; - auto const numDraftTokens = request.generatedTokensPerEngineStep - 1; + auto const& draftTokens = llmReq.getDraftTokens(); + auto const numDraftTokens = numDecodingEngineTokens - 1; - auto const useDraftLogits = request.draftLogits.has_value(); - if (useDraftLogits) - { - TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value()); - - TensorPtr draftLogitsReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1); - draftLogitsReqBatchSlice->squeeze(0); - TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens); - manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice); - } - auto* useDraftLogitsHostPtr = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost); - useDraftLogitsHostPtr[batchIdx] = useDraftLogits; - auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1); - runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream); + auto numDraftTokensHostRange = runtime::BufferRange(*externalDraftTokensInputs->numDraftTokensHost); + numDraftTokensHostRange[batchIdx] = numDraftTokens; + auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1); + runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); if (numDraftTokens > 0) { - TensorPtr draftTokensReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1); - draftTokensReqBatchSlice->squeeze(0); - TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens); - TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens})); - manager.copy(*draftTokensView, *draftTokensReqTokensSlice); + TensorPtr draftTokenIdsHostSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice); + + TensorPtr draftTokenIdsSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice); } - auto* numDraftTokensHostPtr - = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->numDraftTokensHost); - numDraftTokensHostPtr[batchIdx] = numDraftTokens; - auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1); - runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); + auto const& draftLogits = llmReq.getDraftLogits(); + auto const useDraftLogits = draftLogits.has_value(); + auto useDraftLogitsHostRange = runtime::BufferRange(*externalDraftTokensInputs->useDraftLogitsHost); + useDraftLogitsHostRange[batchIdx] = useDraftLogits; + auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1); + runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream); + + if (useDraftLogits) + { + TensorPtr draftLogitsHostSlice + = ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens); + retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig, + speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager); + + TensorPtr draftLogitsSlice + = ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice); + } + + auto const& samplingConfig = llmReq.mSamplingConfig; bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value(); float const constantThreshold = useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0]; - dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; - dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold; + externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; + externalDraftTokensInputs->constantThreshold = constantThreshold; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens) +//! @brief Setups decoder internal tensors for new Medusa request +void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers, + CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs}; + // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? + // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. + auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1); + auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1); + BufferManager manager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + auto& medusaInputs = jointDecodingInput.medusaInputs; TensorPtr curTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1); + = ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1); // Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end // of first decoder runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream); TensorPtr targetTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); - auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep; - TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens, - "Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep, + = ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); + TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens, + "Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens, maxDecodingEngineTokens); - runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream); + runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream); - TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1); - manager.copy(*request.medusaPaths, *pathsSlice); + TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1); + manager.copy(*medusaPaths, *pathsSlice); - TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1); - manager.copy(*request.medusaTreeIds, *treeIdsSlice); + TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1); + manager.copy(*medusaTreeIds, *treeIdsSlice); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Lookahead request +void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.lookaheadOutputs); + TLLM_CHECK(jointDecodingInput.lookaheadInputs); // The first generation step only generate 1 token. TensorPtr curTokensPerStepSlice @@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime: TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Explicit draft tokens request +void newRequestExplicitDraftTokens( + DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers); + auto const inputLen = llmReq.getPromptLen(); + TensorPtr positionIdsBaseSlice = ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1); - runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream); + runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Eagle request +void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, + runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.eagleBuffers); + auto& eagleBuffers = *jointDecodingOutput.eagleBuffers; + + auto const inputLen = llmReq.getPromptLen(); BufferManager manager{std::make_shared(runtimeStream.get())}; - TensorPtr eagleNetCtxRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1); TensorPtr eagleNetCtxContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1); TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetCtxRequestTypesHostSlice)[0] = 0; - runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen; - runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen; - TensorPtr eagleNetGenRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1); TensorPtr eagleNetGenContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1); TensorPtr eagleNetGenPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetGenRequestTypesHostSlice)[0] = 1; - runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen; - runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen; auto const eagleModule = std::dynamic_pointer_cast( modelConfig.getSpeculativeDecodingModulePtr()); std::optional eagleChoicesOpt; - if (request.eagleConfig) + auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig(); + + if (eagleConfig) { - eagleChoicesOpt = request.eagleConfig->getEagleChoices(); + eagleChoicesOpt = eagleConfig->getEagleChoices(); } - if (!request.eagleConfig || !request.eagleConfig->useDynamicTree()) + if (!eagleConfig || !eagleConfig->useDynamicTree()) { - TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1); - TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1); + TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1); + TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1); // eagleConfig is nullptr or Eagle-1 std::vector topKs; @@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +//! @brief Setups decoder internal tensors for new speculative decoding request +void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, + SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, + OptionalRef medusaBuffers, runtime::ModelConfig const& modelConfig, + WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (speculativeDecodingMode.predictsDraftTokens()) + { + BufferManager manager{std::make_shared(decoderStream.get())}; + + TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs); + auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs; + + TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1); + // FIXME: can we skip this? + manager.setZero(*nextDraftTokens); + if (speculativeDecodingMode.variableDraftLength()) + { + TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1); + manager.setZero(*nextDraftTokensLen); + } + } + + if (speculativeDecodingMode.isDraftTokensExternal()) + { + newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig, + worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream); + } + else if (speculativeDecodingMode.isMedusa()) + { + TLLM_CHECK(medusaBuffers); + newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens, + medusaBuffers.value(), decoderStream); + } + else if (speculativeDecodingMode.isLookaheadDecoding()) + { + newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream); + } + else if (speculativeDecodingMode.isExplicitDraftTokens()) + { + newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream); + } + else if (speculativeDecodingMode.isEagle()) + { + newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream); + } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +} // namespace + std::tuple, std::vector> CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, @@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon } inputIds->resize(decoderInputSize); - std::vector decoderRequests; - decoderRequests.reserve(finishedContextRequests.size()); - std::vector lookaheadPrompt; std::vector lookaheadAlgoConfigs; if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) @@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon auto const promptLen = llmReq->getPromptLen(); - auto decoderRequest = decoder_batch::Request{promptLen}; - + SizeType32 numDecodingEngineTokens{1}; if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) { - if (llmReq->hasDraftTokens()) - { - auto const& draftTokens = llmReq->getDraftTokens(); - // Copy to pinned host memory (don't care about stream of bufferManager) - decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); - auto const& draftLogits = llmReq->getDraftLogits(); - if (draftLogits.has_value()) - { - decoderRequest.draftLogits - = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager); - } - decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1; - } - else - { - decoderRequest.generatedTokensPerEngineStep = 1; - } + numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1; } else if (!modelConfig.getSpeculativeDecodingMode().isNone()) { - decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens(); + numDecodingEngineTokens = modelConfig.getMaxDecodingTokens(); } auto& dJointInput = decoderState.getJointDecodingInput(); - auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep; initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens, maxSequenceLength, decoderBufferManager); decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); @@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon { TLLM_CHECK(beamWidth == 1); - if (modelConfig.getSpeculativeDecodingMode().isMedusa()) - { - TLLM_CHECK(medusaBuffers); - llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; - // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? - // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. - decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); - decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); - } - else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) + if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) { lookaheadPrompt.emplace_back(requestIds); @@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); } - else if (modelConfig.getSpeculativeDecodingMode().isEagle()) - { - decoderRequest.eagleConfig - = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); - } - newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig, - decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, - decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), + batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens, + decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig, + mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream); } - decoderRequests.push_back(decoderRequest); - inputOffset += promptLen; } return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; } -std::shared_ptr CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig, - WorldConfig const& worldConfig, std::shared_ptr const& tensor, - BufferManager const& bufferManager) const -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - - if (!mSpeculativeDecodingFastLogits) - { - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL); - } - - if (mIsLeaderInOrchMode) - { - te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo; - std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo)); - auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value(); - - // Broadcast to other ranks if needed - if (worldConfig.isTensorParallel()) - { - auto const& commSession = COMM_SESSION; - auto shape = logits->getShape(); - commSession.bcastValue(shape.d[0], 0); - commSession.bcastValue(shape.d[1], 0); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0); - } - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; - } - - // Get logits from leader rank - auto const& commSession = COMM_SESSION; - int64_t dims[2]; - commSession.bcastValue(dims[0], 0); - commSession.bcastValue(dims[1], 0); - auto const logitsDtype = modelConfig.getLogitsDtype(); - auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0); - - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 22756f2552..eaa2e957e8 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session) TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA"); return false; } - - if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor) - { - TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor"); - return false; - } if (selfConfig.getParallelConfig().mEnableAttentionDP && (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0)) { TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size"); return false; } + if (selfConfig.getParallelConfig().mContextParallelism != 1 + || destConfig.getParallelConfig().mContextParallelism != 1) + { + TLLM_LOG_WARNING( + "MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).", + selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism); + return false; + } if (destConfig.getParallelConfig().mEnableAttentionDP && (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0)) { diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp index 484cd7c3c7..7234ca9ba5 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp @@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS #endif // ENABLE_MULTI_DEVICE } -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig) +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype) { #if ENABLE_MULTI_DEVICE auto const& worldComm = tensorrt_llm::mpi::MpiComm::world(); @@ -151,10 +151,7 @@ std::optional targetModelReceiveLogits( int64_t dims[2]; MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status)); - auto const logitsDtype = modelConfig.getLogitsDtype(); - - auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool( - runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype); + draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]})); worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status); @@ -163,11 +160,7 @@ std::optional targetModelReceiveLogits( uint64_t const expectedSize = static_cast(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype); TLLM_CHECK((uint64_t) count == expectedSize); - MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status)); - - return tensor; -#else - return std::nullopt; + MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status)); #endif // ENABLE_MULTI_DEVICE } diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h index 6d87ebee16..f19d5f5ef3 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h @@ -21,10 +21,8 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" -#include "tensorrt_llm/runtime/modelConfig.h" #include -#include namespace tensorrt_llm::batch_manager { @@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS std::shared_ptr const& crossKvCacheManager, std::shared_ptr const& peftCacheManager); -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig); +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype); } // namespace tensorrt_llm::batch_manager::utils diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h index 3fab43a3b4..e226c0d812 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -435,6 +435,14 @@ struct CutlassGemmConfig int sm_version = 80; // Use 80 as a catch all for <90 bool is_tma_warp_specialized = false; + enum class EpilogueFusionType : int + { + NONE, + FINALIZE + }; + + EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE; + CutlassGemmConfig() = default; CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) @@ -505,7 +513,8 @@ struct CutlassGemmConfig << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() << "\n\tcluster shape ID: " << (int) cluster_shape << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule - << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false") + << "\n\tepilogue fusion type: " << (int) epilogue_fusion_type; } else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { @@ -537,7 +546,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) << ", cluster_shape_enum: " << int(config.cluster_shape) - << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false") + << ", epilogue_fusion_type: " << int(config.epilogue_fusion_type); } else { diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 738a095eef..bba8d19e2f 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is) auto tokensPerBlock = su::deserialize(is); auto tensorParallelism = su::deserialize(is); auto pipelineParallelism = su::deserialize(is); + auto contextParallelism = su::deserialize(is); auto enableAttentionDP = su::deserialize(is); auto DPrank = su::deserialize(is); auto DPsize = su::deserialize(is); auto dataType = su::deserialize(is); auto attentionType = su::deserialize(is); auto kvFactor = su::deserialize(is); - return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType, - attentionType, kvFactor, enableAttentionDP, DPrank, DPsize}; + return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, + contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize}; } void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os) @@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o su::serialize(state.mModelConfig.mTokensPerBlock, os); su::serialize(state.mParallelConfig.mTensorParallelism, os); su::serialize(state.mParallelConfig.mPipelineParallelism, os); + su::serialize(state.mParallelConfig.mContextParallelism, os); su::serialize(state.mParallelConfig.mEnableAttentionDP, os); su::serialize(state.mParallelConfig.mDPrank, os); su::serialize(state.mParallelConfig.mDPsize, os); @@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state) totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock); totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism); totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism); + totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism); totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP); totalSize += su::serializedSize(state.mParallelConfig.mDPrank); totalSize += su::serializedSize(state.mParallelConfig.mDPsize); diff --git a/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu new file mode 100644 index 0000000000..eb3b958eb2 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moeTopKFuncs.cuh" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/archCondition.h" +#include "tensorrt_llm/kernels/customMoeRoutingKernels.h" +#include // For INT_MAX +#include +#include +#include +#include // For numeric_limits +#include + +namespace cg = cooperative_groups; +using namespace tensorrt_llm::common; + +namespace tensorrt_llm::kernels +{ + +static constexpr int BLOCK_SIZE = 1024; +static constexpr int WARP_SIZE = 32; +static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ T calcSoftmax(cg::thread_block_tile const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) +{ + T maxScore = T{-INFINITY}; + if (laneIdx < NumTopExperts) + { + maxScore = score >= maxScore ? score : maxScore; + } + maxScore = cg::reduce(warp, maxScore, cg::greater()); + + float sumScore{0.f}; + float newScore; + // Get the summation of scores for each token + if (laneIdx < NumTopExperts) + { + newScore = static_cast(score) - static_cast(maxScore); + newScore = static_cast(exp(newScore)); + sumScore += newScore; + } + sumScore = cg::reduce(warp, sumScore, cg::plus()); + + if (laneIdx < NumTopExperts) + { + score = static_cast(newScore / sumScore); + } + + return score; +} + +template +__device__ void calcSoftmax(cg::thread_block_tile const& warp, DataType (&scores)[VecSize]) +{ + DataType maxScore = DataType{-INFINITY}; + DataType sumScore = DataType{0.f}; + + // Get the max score for each token +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + maxScore = scores[i] >= maxScore ? scores[i] : maxScore; + } + maxScore = cg::reduce(warp, maxScore, cg::greater()); + + // Get the summation of scores for each token +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + scores[i] = static_cast(exp(scores[i] - maxScore)); + sumScore += scores[i]; + } + sumScore = cg::reduce(warp, sumScore, cg::plus()); + + // Normalize the scores +#pragma unroll + for (int i = 0; i < VecSize; ++i) + { + scores[i] = static_cast(scores[i] / sumScore); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, + int32_t const numTokens, int32_t const numExperts, int32_t const topK) +{ + using BaseType = std::conditional_t; + uint32_t const blockRank = blockIdx.x; + uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; + uint32_t const warpIdx = tIdx / WARP_SIZE; + uint32_t const laneIdx = tIdx % WARP_SIZE; + uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + BaseType minScore = BaseType{-INFINITY}; + for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) + { + auto scoreOffset = tokenId * numExperts; + auto outputOffset = tokenId * topK; + + BaseType inputScore[MaxNumExperts / WARP_SIZE]; + IdxT inputIndex[MaxNumExperts / WARP_SIZE]; + + BaseType warpTopKScore[MaxNumTopExperts]; + IdxT warpTopKExpertIdx[MaxNumTopExperts]; + + // Load scores and indices for this warp + for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) + { + auto expertIdx = i * WARP_SIZE + laneIdx; + inputScore[i] + = expertIdx < numExperts ? static_cast(routerLogits[scoreOffset + expertIdx]) : minScore; + inputIndex[i] = expertIdx; + } + + if constexpr (DoSoftmaxBeforeTopK) + { + calcSoftmax(warp, inputScore); + } + // Reduce topK scores and indices for this warp + reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); + + // Normalize the scores + if constexpr (DoSoftmaxBeforeTopK) + { + if (laneIdx < topK) + { + topkValues[outputOffset + laneIdx] = static_cast(warpTopKScore[laneIdx]); + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } + else + { + auto softmaxScore = calcSoftmax(warp, + laneIdx < topK ? static_cast(warpTopKScore[laneIdx]) : static_cast(minScore), laneIdx, + topK); + if (laneIdx < topK) + { + topkValues[outputOffset + laneIdx] = static_cast(softmaxScore); + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } + } // end for tokenId +} + +int nextPowerOfTwo(int num) +{ + if (num <= 0) + { + return 1; // Handle invalid input + } + int power = 1; + while (power < num) + { + // Check for overflow before shifting + if (power > INT_MAX / 2) + { + return power; + } + power <<= 1; + } + return power; +} + +#define CASE(MAX_NUM_EXPERTS) \ + case MAX_NUM_EXPERTS: \ + switch (maxNumTopExperts) \ + { \ + case 1: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 2: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 4: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + case 8: \ + kernelInstance = &customMoeRoutingKernel; \ + break; \ + default: kernelInstance = nullptr; break; \ + } \ + break; + +template +void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, + int64_t const numExperts, int64_t const topK, cudaStream_t const stream) +{ + + const uint32_t maxNumBlocks = 1024; + const uint32_t numBlocks = std::min(static_cast((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); + + uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); + uint32_t maxNumTopExperts = nextPowerOfTwo(topK); + + auto* kernelInstance = &customMoeRoutingKernel; + + switch (maxNumExperts) + { + CASE(32) + CASE(64) + CASE(96) + CASE(128) + default: kernelInstance = nullptr; break; + } + + if (kernelInstance == nullptr) + { + TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); + } + + dim3 renormMoeRoutingGridDim(numBlocks); + dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); + cudaLaunchConfig_t config; + config.gridDim = renormMoeRoutingGridDim; + config.blockDim = renormMoeRoutingBlockDim; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast(numTokens), + static_cast(numExperts), static_cast(topK)); + sync_check_cuda_error(stream); +} + +#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \ + template void invokeRenormMoeRouting(InputT * routerLogits, \ + OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \ + int64_t const topK, cudaStream_t const stream); + +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false); +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false); +#ifdef ENABLE_BF16 +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false); +#endif + +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true); +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true); +#ifdef ENABLE_BF16 +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true); +#endif + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h similarity index 86% rename from cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h rename to cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h index 1e9b001f65..cfe0ae8f15 100644 --- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.h +++ b/cpp/tensorrt_llm/kernels/customMoeRoutingKernels.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace tensorrt_llm::kernels { -template +template void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, cudaStream_t const stream); } // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index ba755ca669..2c0894ab57 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -288,15 +288,20 @@ public: void moeGemm(GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs); - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getTmaWarpSpecializedConfigs(int sm); - static std::vector getBlackwellConfigs(int sm); - static std::vector getHopperConfigs(int sm); + std::vector getConfigs(bool supports_finalize_fusion) const; + static std::vector getConfigs(int sm, bool supports_finalize_fusion); + static std::vector getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion); static std::vector getAmpereConfigs(int sm); [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsTmaWarpSpecialized() const; + + [[nodiscard]] bool supportsTmaWarpSpecialized() const + { + return supportsTmaWarpSpecialized(sm_); + } + + [[nodiscard]] static bool supportsTmaWarpSpecialized(int sm); [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const; [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 7d592bed0e..389591e7fe 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -228,6 +228,13 @@ struct MOEParallelismConfig } }; +enum class MoeGemmId : int +{ + Undefined = 0, + GEMM_1, + GEMM_2 +}; + struct QuantParams { // Int weight only quantization params @@ -446,7 +453,7 @@ public: virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; - virtual std::vector getTactics() = 0; + virtual std::vector getTactics(MoeGemmId gemm_id) = 0; virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, @@ -593,15 +600,15 @@ public: gemm2_config_ = std::move(gemm2_config); } - std::vector getTactics() override + std::vector getTactics(MoeGemmId gemm_id) override { - return moe_gemm_runner_.getConfigs(); + return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused()); } - static std::vector getTactics(int sm) + static std::vector getTactics(int sm, MoeGemmId gemm_id) { using RunnerType = decltype(moe_gemm_runner_); - return RunnerType::getConfigs(sm); + return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm)); } void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, @@ -798,6 +805,12 @@ private: && !use_w4_groupwise; } + static bool mayHaveFinalizeFused(int sm) + { + using RunnerType = decltype(moe_gemm_runner_); + return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise; + } + // TODO: This should eventually take the quant params to give more flexibility static auto getScalingType() { @@ -895,12 +908,7 @@ struct GemmProfilerBackend { public: using Config = cutlass_extensions::CutlassGemmConfig; - enum class GemmToProfile - { - Undefined = 0, - GEMM_1, - GEMM_2 - }; + using GemmToProfile = MoeGemmId; void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype, nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size, @@ -951,7 +959,6 @@ public: CutlassMoeFCRunnerInterface* mInterface; GemmToProfile mGemmToProfile = GemmToProfile::Undefined; - std::vector mAllTacticsSaved; int mSM{}; int64_t mNumExperts{}; int64_t mNumExpertsPerNode{}; @@ -972,7 +979,7 @@ public: // This will be a unique value for every iteration of warmup and actual bench constexpr static int64_t NUM_ROUTING_SAMPLES = 16; - std::array mTmaInputCache; + std::array, NUM_ROUTING_SAMPLES> mTmaInputCache; QuantParams mQuantParams; bool mBias{}; @@ -985,7 +992,8 @@ public: private: void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); - void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream); + void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream); }; // Populates a buffer with random values for use with MOE benchmarking diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h index c44caae0fa..ef06abceee 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h @@ -57,7 +57,6 @@ namespace kernels { namespace cutlass_kernels { - template void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 5a07062f06..899e0787cd 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -475,17 +475,18 @@ void dispatchMoeGemmToCutlass(GroupedGemmInput -std::vector -MoeGemmRunner::getConfigs() const +std::vector MoeGemmRunner::getConfigs( + bool supports_finalize_fusion) const { - return getConfigs(sm_); + return getConfigs(sm_, supports_finalize_fusion); } template std::vector MoeGemmRunner::getConfigs( - int sm) + int sm, bool supports_finalize_fusion) { - std::vector candidate_configs = getTmaWarpSpecializedConfigs(sm); + std::vector candidate_configs + = getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); return candidate_configs; @@ -521,7 +522,8 @@ MoeGemmRunner::getAmpereConfigs(int sm template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs(int sm) +MoeGemmRunner::getTmaWarpSpecializedConfigs( + int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag @@ -568,6 +570,17 @@ MoeGemmRunner::getTmaWarpSpecializedCo = tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param); std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs)); } + if (supports_finalize_fusion) + { + // Duplicate the configs and set the epilogue fusion type to FINALIZE + auto finalize_configs = tma_ws_configs; + std::transform(finalize_configs.begin(), finalize_configs.end(), std::back_inserter(tma_ws_configs), + [](auto& config) + { + config.epilogue_fusion_type = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + return config; + }); + } return tma_ws_configs; } @@ -580,13 +593,11 @@ bool MoeGemmRunner::isTmaWarpSpecializ } template -bool MoeGemmRunner::supportsTmaWarpSpecialized() const +bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { - return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - || (sm_ >= 100 && sm_ < 120 - && tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) - || ((sm_ == 120 || sm_ == 121) - && tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()); + return (sm == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + || (sm >= 100 && sm < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation()) + || ((sm == 120 || sm == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } template @@ -833,7 +844,9 @@ size_t MoeGemmRunner::calcMaxWorkspace if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8 && !use_wfp4a16) { - auto configs = getTmaWarpSpecializedConfigs(sm_); + // Finalize fusion may not actually be supported by the kernel, + // if they are not we will catch the error and skip them + auto configs = getTmaWarpSpecializedConfigs(sm_, true); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp8) { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 730840717c..ef70b9d45e 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2847,9 +2847,10 @@ void CutlassMoeFCRunnerepilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; permuted_token_final_scales_ - = (gemm2_using_tma_ws && mayHaveFinalizeFused()) ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; + = gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr; bool const is_gated_activation = isGatedActivation(activation_type); bool const gemm1_using_fused_moe @@ -4006,8 +4007,12 @@ CutlassMoeFCRunner:: bool apply_bias = parallelism_config.tp_rank == 0; auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr; + bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool using_fused_finalize - = use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora; + = use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora; + TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion, + "GEMM2 tactic requests finalize fusion, but the runner is not configured to use it"); if (using_fused_finalize) { assert(min_latency_mode == false); @@ -4550,14 +4555,26 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr } } -void GemmProfilerBackend::prepareTmaWsInputs( - int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream) +void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, void const* expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream) { if (mSM < 90) { return; } + bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); + bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) + && mWType == nvinfer1::DataType::kUINT8); + bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; + bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise + || mGemmToProfile != GemmToProfile::GEMM_2; + if (use_finalize_fusion && finalize_fusion_not_supported) + { + return; + } + auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR(type, name) \ @@ -4596,28 +4613,24 @@ void GemmProfilerBackend::prepareTmaWsInputs( size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { - mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, + // Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same pointers to + // save space. + auto& cache_element = mTmaInputCache[i][use_finalize_fusion]; + cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, workspaces.at("gemm_workspace").first, mScalingType); tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1); int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; - auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input; - auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input; + auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input; + auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input; if (mSM >= 90) { /* GEMM1 */ gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; - - bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); - bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) - && mWType == nvinfer1::DataType::kUINT8); - bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; - bool using_fused_finalize - = mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise; - if (using_fused_finalize) + if (use_finalize_fusion) { assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; @@ -4652,7 +4665,6 @@ void GemmProfilerBackend::prepareTmaWsInputs( void GemmProfilerBackend::prepare( int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream) { - mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; auto workspace_size = getWorkspaceSize(num_tokens); @@ -4660,7 +4672,10 @@ void GemmProfilerBackend::prepare( prepareRouting(num_tokens, workspace_ptr_char, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); - prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream); + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, stream); + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, + TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, stream); } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) @@ -4724,7 +4739,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac TmaWarpSpecializedGroupedGemmInput tma_ws_input_template; if (tactic.is_tma_warp_specialized) { - tma_ws_input_template = mTmaInputCache[mSampleIndex]; + tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type + == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE]; + TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), "TMA WS input template is not initialized"); } mInterface->is_profiler = true; diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 08cd9b6f66..5ebd5f7ebe 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf -size 67051604 +oid sha256:d6a3f6adef11003f794a6cec1235d0c622ead71b4e801a89866e91dfd91bb30c +size 67053244 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index 8b500f5c97..b93f46ea6d 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 9c0a42825905952beaf9b35d5a35d58de1a123fa +317a25037093a6f3d156ffa58a68bce53071ef68dacdcb04cc0aaeea80b64e76 libtensorrt_llm_internal_cutlass_kernels_static.a +commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index f1a6b9dc88..bd07528460 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d -size 66872936 +oid sha256:489fb557b78062efedd1514f2995fafb9216bb0e0068a550e86763efb9d5eee9 +size 66874608 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index 4af58b0800..3c053c1a91 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a -commit 9c0a42825905952beaf9b35d5a35d58de1a123fa +5a31acd0fb1415196bff71fa4a8d1dded147e15ea10821cc46c85684c66986ee libtensorrt_llm_internal_cutlass_kernels_static.a +commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b diff --git a/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh new file mode 100644 index 0000000000..933b599dbd --- /dev/null +++ b/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh @@ -0,0 +1,205 @@ + +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifndef TRTLLM_MOETOPKFUNCS_CUH_H +#define TRTLLM_MOETOPKFUNCS_CUH_H + +#include +#include +#include + +#include "tensorrt_llm/kernels/archCondition.h" + +namespace tensorrt_llm::kernels +{ + +namespace reduce_topk +{ +namespace cg = cooperative_groups; +static constexpr int kWARP_SIZE = 32; +static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>; + +template +struct TopKRedType +{ + using T = T_; + static_assert(std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v, + "Top K reduction only implemented for int, float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + + static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16; + static constexpr int kMaxIdx = 65535; + TypeCmp compValIdx; + + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) + { + auto valueBits = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = reinterpret_cast(valueBits); + compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller indices. + return compactTmp; + } + + static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) + { + // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits + index = kMaxIdx - static_cast((cmp & 0xFFFF)); + + auto compactTmp = cmp >> kMoveBits; + auto valueBits + = cub::Traits::TwiddleOut(reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(T val, int32_t idx) + : compValIdx(makeCmpVal(val, idx)) + { + } + + __host__ __device__ operator TypeCmp() const noexcept + { + return compValIdx; + } + + __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) + { + if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) + { + return cg::reduce(warp, compValIdx, cg::greater{}); + } + else + { + TypeCmp result; + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); + return result; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKIdx +{ + // by default, empty +}; + +template +struct TopKIdx +{ + static constexpr int K = K_; + int32_t val[K]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ + topK[I].compValIdx = pairMax; \ + topK[J].compValIdx = pairMin; \ + } + +template +struct Sort; + +template +struct Sort<1, RedType> +{ + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<3, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +template +__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], + Type (&value)[N], int32_t (&idx)[N], Type minValue) +{ + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + if constexpr (!IsSorted) + { + Sort::run(topK); + } + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < K; ++kk) + { + bool update = kk > 0 && packedMax == topK[0].compValIdx; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +#undef TOPK_SWAP + +} // namespace reduce_topk +} // namespace tensorrt_llm::kernels +#endif // TRTLLM_MOETOPKFUNCS_CUH_H diff --git a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu b/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu deleted file mode 100644 index 1b4239e48c..0000000000 --- a/cpp/tensorrt_llm/kernels/renormMoeRoutingKernels.cu +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include "tensorrt_llm/common/envUtils.h" -#include "tensorrt_llm/kernels/archCondition.h" -#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h" -#include // For INT_MAX -#include -#include -#include -#include // For numeric_limits -#include - -namespace cg = cooperative_groups; -using namespace tensorrt_llm::common; - -namespace tensorrt_llm::kernels -{ - -static constexpr int BLOCK_SIZE = 1024; -static constexpr int WARP_SIZE = 32; -static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; - -namespace reduce_topk -{ - -static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>; - -template -struct TopKRedType -{ - using T = T_; - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "Top K reduction only implemented for float, float16 and bfloat16"); - - using TypeCmp = std::conditional_t; - using IdxT = std::conditional_t; - static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16; - static constexpr int maxIdx = 65535; - TypeCmp compValIdx; - - static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) - { - auto valueBits = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); - TypeCmp compactTmp = reinterpret_cast(valueBits); - compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); - // Use 65535 minus idx to give higher priority to elements with smaller indices. - return compactTmp; - } - - static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) - { - // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits - index = maxIdx - static_cast((cmp & 0xFFFF)); - - auto compactTmp = cmp >> moveBits; - auto valueBits - = cub::Traits::TwiddleOut(reinterpret_cast::UnsignedBits&>(compactTmp)); - value = reinterpret_cast(valueBits); - } - - __host__ __device__ TopKRedType() = default; - - __host__ __device__ TopKRedType(T val, int32_t idx) - : compValIdx(makeCmpVal(val, idx)) - { - } - - __host__ __device__ operator TypeCmp() const noexcept - { - return compValIdx; - } - - __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) - { - if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8) - { - return cg::reduce(warp, compValIdx, cg::greater{}); - } - else - { - TypeCmp result; - asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); - return result; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKIdx -{ - // by default, empty -}; - -template -struct TopKIdx -{ - static constexpr int K = K_; - int32_t val[K]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ - auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ - topK[I].compValIdx = pairMax; \ - topK[J].compValIdx = pairMin; \ - } - -template -struct Sort; - -template -struct Sort<1, RedType> -{ - static __device__ void run(RedType* topK) {} -}; - -template -struct Sort<2, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 1); - } -}; - -template -struct Sort<3, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 1); - TOPK_SWAP(1, 2); - TOPK_SWAP(0, 1); - } -}; - -template -struct Sort<4, RedType> -{ - static __device__ void run(RedType* topK) - { - TOPK_SWAP(0, 2); - TOPK_SWAP(1, 3); - TOPK_SWAP(0, 1); - TOPK_SWAP(2, 3); - TOPK_SWAP(1, 2); - } -}; - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); - using RedType = TopKRedType; - RedType topK[N]; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = RedType{value[nn], idx[nn]}; - } - - if constexpr (!IsSorted) - { - Sort::run(topK); - } - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - bool update = kk > 0 && packedMax == topK[0].compValIdx; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; - } - // get the next largest value - packedMax = topK[0].reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -#undef TOPK_SWAP - -} // end of namespace reduce_topk - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ T calcSoftmax(cg::thread_block_tile const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) -{ - T maxScore = T{-INFINITY}; - if (laneIdx < NumTopExperts) - { - maxScore = score >= maxScore ? score : maxScore; - } - maxScore = cg::reduce(warp, maxScore, cg::greater()); - - float sumScore = float{0.f}; - float newScore; - // Get the summation of scores for each token - if (laneIdx < NumTopExperts) - { - newScore = static_cast(score) - static_cast(maxScore); - newScore = static_cast(exp(newScore)); - sumScore += newScore; - } - sumScore = cg::reduce(warp, sumScore, cg::plus()); - - if (laneIdx < NumTopExperts) - { - score = static_cast(newScore / sumScore); - } - - return score; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, - int32_t const numTokens, int32_t const numExperts, int32_t const topK) -{ - - uint32_t const blockRank = blockIdx.x; - uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; - uint32_t const warpIdx = tIdx / WARP_SIZE; - uint32_t const laneIdx = tIdx % WARP_SIZE; - uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - InputT minScore = InputT{-INFINITY}; - for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) - { - auto scoreOffset = tokenId * numExperts; - auto outputOffset = tokenId * topK; - InputT inputScore[MaxNumExperts / WARP_SIZE]; - IdxT inputIndex[MaxNumExperts / WARP_SIZE]; - - InputT warpTopKScore[MaxNumTopExperts]; - IdxT warpTopKExpertIdx[MaxNumTopExperts]; - - // Load scores and indices for this warp - for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) - { - auto expertIdx = i * WARP_SIZE + laneIdx; - inputScore[i] - = expertIdx < numExperts ? static_cast(routerLogits[scoreOffset + expertIdx]) : minScore; - inputIndex[i] = expertIdx; - } - - // Reduce topK scores and indices for this warp - reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); - - // Perform softmax on topK scores - auto score = calcSoftmax(warp, - laneIdx < topK ? static_cast(warpTopKScore[laneIdx]) : static_cast(minScore), laneIdx, topK); - if (laneIdx < topK) - { - topkValues[outputOffset + laneIdx] = static_cast(score); - topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; - } - } // end for tokenId -} - -int nextPowerOfTwo(int num) -{ - if (num <= 0) - { - return 1; // Handle invalid input - } - int power = 1; - while (power < num) - { - // Check for overflow before shifting - if (power > INT_MAX / 2) - { - return power; - } - power <<= 1; - } - return power; -} - -#define CASE(MAX_NUM_EXPERTS) \ - case MAX_NUM_EXPERTS: \ - switch (maxNumTopExperts) \ - { \ - case 1: kernelInstance = &renormMoeRoutingKernel; break; \ - case 2: kernelInstance = &renormMoeRoutingKernel; break; \ - case 4: kernelInstance = &renormMoeRoutingKernel; break; \ - case 8: kernelInstance = &renormMoeRoutingKernel; break; \ - default: kernelInstance = nullptr; break; \ - } \ - break; - -template -void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, - int64_t const numExperts, int64_t const topK, cudaStream_t const stream) -{ - - const uint32_t maxNumBlocks = 1024; - const uint32_t numBlocks = std::min(static_cast((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); - - uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); - uint32_t maxNumTopExperts = nextPowerOfTwo(topK); - - auto* kernelInstance = &renormMoeRoutingKernel; - - switch (maxNumExperts) - { - CASE(32) - CASE(64) - CASE(96) - CASE(128) - default: kernelInstance = nullptr; break; - } - - if (kernelInstance == nullptr) - { - TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); - } - - dim3 renormMoeRoutingGridDim(numBlocks); - dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); - cudaLaunchConfig_t config; - config.gridDim = renormMoeRoutingGridDim; - config.blockDim = renormMoeRoutingBlockDim; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast(numTokens), - static_cast(numExperts), static_cast(topK)); - sync_check_cuda_error(stream); -} - -#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT) \ - template void invokeRenormMoeRouting(InputT * routerLogits, OutputT * topkValues, \ - IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, \ - cudaStream_t const stream); - -INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t); -INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t); -#ifdef ENABLE_BF16 -INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t); -#endif - -} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/topkLastDim.cu b/cpp/tensorrt_llm/kernels/topkLastDim.cu index b13cd00b8f..e6e4e82c92 100644 --- a/cpp/tensorrt_llm/kernels/topkLastDim.cu +++ b/cpp/tensorrt_llm/kernels/topkLastDim.cu @@ -22,11 +22,17 @@ */ #include +#include "moeTopKFuncs.cuh" #include "topkLastDim.h" +#include +#include #include #include +#include +#include #include #include +#include namespace tensorrt_llm { @@ -203,12 +209,12 @@ __host__ __device__ IdxT calc_buf_len(IdxT len) * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ -template -__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f) +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, IdxT len, Func f) { if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = thread_rank; i < len; i += num_threads) + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } @@ -233,12 +239,12 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - for (idxT i = thread_rank; i < len_cast; i += num_threads) + for (IdxT i = thread_rank; i < len_cast; i += num_threads) { wide.scalar = in_cast[i]; - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { @@ -258,7 +264,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con // and so // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE // no need to use loop - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); @@ -267,14 +273,14 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con } // sync_width should >= WARP_SIZE -template -__device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width) +template +__device__ void vectorized_process(T const* in, IdxT len, Func f, int sync_width) { - const idxT stride = blockDim.x * gridDim.x; - const idxT tid = blockIdx.x * blockDim.x + threadIdx.x; + const IdxT stride = blockDim.x * gridDim.x; + const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { - for (idxT i = tid; i < len; i += stride) + for (IdxT i = tid; i < len; i += stride) { f(in[i], i, true); } @@ -298,17 +304,17 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); - const idxT len_cast = (len - skip_cnt) / items_per_scalar; + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; - const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (idxT i = tid; i < len_cast_for_sync; i += stride) + const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (IdxT i = tid; i < len_cast_for_sync; i += stride) { bool valid = i < len_cast; if (valid) { wide.scalar = in_cast[i]; } - const idxT real_i = skip_cnt + i * items_per_scalar; + const IdxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { @@ -325,7 +331,7 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width T value = valid ? in[tid] : T(); f(value, tid, valid); - const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; valid = remain_i < len; value = valid ? in[remain_i] : T(); f(value, remain_i, valid); @@ -1166,6 +1172,77 @@ __global__ void radix_topk_one_block_kernel(T const* in, IdxT const* in_idx, con } // namespace air_topk_stable //} +namespace moe_topk +{ +namespace cg = cooperative_groups; +static constexpr int kBLOCK_SIZE = 1024; +static constexpr int kWARP_SIZE = 32; +static constexpr int kWARPS_PER_BLOCK = kBLOCK_SIZE / kWARP_SIZE; + +template +__device__ T negativeInfinity() +{ + return -INFINITY; +} + +template <> +__device__ half negativeInfinity() +{ + return -CUDART_INF_FP16; +} + +template <> +__device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>() +{ + return -CUDART_INF_BF16; +} + +/****************TopK kernel for candidate number<= 128 and K <= 8 **************** */ +template +__global__ void moe_topk_kernel( + InputT const* in, OutputT* out, IdxT* outIdx, int32_t const batchSize, int32_t const len, int32_t const topK) +{ + + uint32_t const blockRank = blockIdx.x; + uint32_t const tIdx = kBLOCK_SIZE * blockRank + threadIdx.x; + uint32_t const warpIdx = tIdx / kWARP_SIZE; + uint32_t const laneIdx = tIdx % kWARP_SIZE; + uint32_t const warpNum = gridDim.x * kWARPS_PER_BLOCK; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + InputT minScore = negativeInfinity(); + + for (uint32_t tokenId = warpIdx; tokenId < batchSize; tokenId += warpNum) + { + auto scoreOffset = tokenId * len; + auto outputOffset = tokenId * topK; + InputT inputScore[MaxLen / kWARP_SIZE]; + IdxT inputIndex[MaxLen / kWARP_SIZE]; + + InputT warpTopKScore[MaxTopK]; + IdxT warpTopKExpertIdx[MaxTopK]; + + // Load scores and indices for this warp + for (uint32_t i = 0; i < MaxLen / kWARP_SIZE; ++i) + { + auto expertIdx = i * kWARP_SIZE + laneIdx; + inputScore[i] = expertIdx < len ? static_cast(in[scoreOffset + expertIdx]) : minScore; + inputIndex[i] = expertIdx; + } + + // Reduce topK scores and indices for this warp + tensorrt_llm::kernels::reduce_topk::reduceTopK( + warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); + + if (laneIdx < topK) + { + out[outputOffset + laneIdx] = static_cast(warpTopKScore[laneIdx]); + outIdx[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; + } + } // end for tokenId +} +} // namespace moe_topk /***************Runtime API****************/ @@ -1223,9 +1300,11 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx IdxT* sort_in_idx = nullptr; air_topk_stable::ComputeOffset computeoffset(k); + thrust::counting_iterator counting_iter(0); thrust::transform_iterator, thrust::counting_iterator> transform_iter( counting_iter, computeoffset); + cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) @@ -1277,8 +1356,8 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx sort_in = static_cast(aligned_pointers[9]); sort_in_idx = static_cast(aligned_pointers[10]); } - cudaMemsetAsync( - buf, 0, static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), stream); + cudaMemsetAsync(aligned_pointers[0], 0, + static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), stream); } T const* in_buf = nullptr; @@ -1423,36 +1502,120 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons } } -template -void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, idxT len, idxT k, T* out, - idxT* out_idx, bool greater, cudaStream_t stream = 0) +template +void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, IdxT len, IdxT k, T* out, + IdxT* out_idx, bool greater, cudaStream_t stream = 0) { constexpr int items_per_thread = 32; constexpr int block_dim = 512; constexpr bool fused_last_filter = false; if (len <= block_dim * items_per_thread) { - standalone_stable_radix_topk_one_block_( - buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); + standalone_stable_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { int sm_cnt = tensorrt_llm::common::getMultiProcessorCount(); - unsigned grid_dim = air_topk_stable::calc_grid_dim(batch_size, len, sm_cnt); + unsigned grid_dim = air_topk_stable::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - standalone_stable_radix_topk_one_block_(buf, buf_size, in, - static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); + standalone_stable_radix_topk_one_block_(buf, buf_size, in, + static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { - standalone_stable_radix_topk_(buf, buf_size, in, static_cast(nullptr), + standalone_stable_radix_topk_(buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, fused_last_filter, grid_dim, stream, sorted); } } } +int nextPowerOfTwo(int num) +{ + if (num <= 0) + { + return 1; // Handle invalid input + } + int power = 1; + while (power < num) + { + // Check for overflow before shifting + if (power > INT_MAX / 2) + { + return power; + } + power <<= 1; + } + return power; +} + +template +void moe_reduce_topk( + T const* in, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool greater, cudaStream_t stream = 0) +{ + using InputT = T; + using OutputT = T; + const uint32_t max_num_blocks = 1024; + const uint32_t num_blocks + = std::min(static_cast((batch_size - 1) / moe_topk::kWARPS_PER_BLOCK + 1), max_num_blocks); + + uint32_t max_len = nextPowerOfTwo(len) < 32 ? 32 : nextPowerOfTwo(len); + uint32_t moe_topk = nextPowerOfTwo(k); + + auto* kernel_instance = &moe_topk::moe_topk_kernel; + + switch (max_len) + { + case 32: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 64: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 96: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + case 128: + switch (moe_topk) + { + case 1: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 2: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 4: kernel_instance = &moe_topk::moe_topk_kernel; break; + case 8: kernel_instance = &moe_topk::moe_topk_kernel; break; + default: kernel_instance = nullptr; break; + } + break; + default: kernel_instance = nullptr; break; + } + + dim3 moe_topk_grid_dim(num_blocks); + dim3 moe_topk_block_dim(moe_topk::kBLOCK_SIZE); + + kernel_instance<<>>(in, out, out_idx, batch_size, len, k); +} #endif /////////////// @@ -1461,22 +1624,22 @@ template size_t invokeComputeTopkLastDimWorkspaceSize( SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest) { - using idxT = SizeType32; + using IdxT = SizeType32; size_t buf_size = 0; void* workspace = nullptr; T const* in = nullptr; T* out_val = nullptr; - idxT* out_idx = nullptr; + IdxT* out_idx = nullptr; constexpr int block_dim = 512; constexpr bool fused_last_filter = false; constexpr bool sorted = true; int sm_cnt = tensorrt_llm::common::getMultiProcessorCount(); - unsigned grid_dim = air_topk_stable::calc_grid_dim(batchSize, inputLength, sm_cnt); + unsigned grid_dim = air_topk_stable::calc_grid_dim(batchSize, inputLength, sm_cnt); - standalone_stable_radix_topk_(workspace, buf_size, in, static_cast(nullptr), + standalone_stable_radix_topk_(workspace, buf_size, in, static_cast(nullptr), batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted); return buf_size; } @@ -1506,8 +1669,17 @@ void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32 T const* in = reinterpret_cast(input); T* out_val_ = reinterpret_cast(out_val); SizeType32* out_idx_ = reinterpret_cast(out_idx); - standalone_stable_radix_11bits( - workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream); + if (inputLength <= 128 && k <= 8 && is_largest == true) + { + // This method does not require a buffer, but since the implementation may vary in different cases, + // we still allocate the buffer in case AIR TopK is used instead. + moe_reduce_topk(in, batchSize, inputLength, k, out_val_, out_idx_, !is_largest, stream); + } + else + { + standalone_stable_radix_11bits( + workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream); + } } #define INSTANTIATE_TOPK_LastDim_DATA_TYPE(T) \ diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh index 750658fad7..92d020fd19 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh @@ -378,7 +378,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens // TODO: this is not sufficient to ensure visibility in the next kernel! -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); @@ -757,15 +757,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -// Trigger secondary kernel. -// Note: this does not guarantee the visibility of prior writes unless the consumer executes a -// dependency sync. -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 + // Trigger secondary kernel. + // Note: this does not guarantee the visibility of prior writes unless the consumer executes a + // dependency sync. if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu index f1f60abdc2..5c39892039 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu @@ -227,13 +227,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 // we can trigger the next kernel at this point if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif // at this point, all values for offsets are ready, except the final offsets diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu index f03e02c2e2..f6364e0cc9 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu @@ -199,13 +199,11 @@ __global__ void __launch_bounds__(NumThreadsSingleBlock) routingIndicesBlockKern } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -#if !defined(PDL_PROFILE) || PDL_PROFILE == 0 // we can trigger the next kernel at this point if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif #endif for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt index 1ccb50a02b..8c62584108 100755 --- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt +++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt @@ -43,7 +43,7 @@ target_link_libraries( ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python - CUDA::cuda_driver + ${CUDA_DRV_LIB} ${CUDA_NVML_LIB} th_common) target_compile_definitions( diff --git a/cpp/tensorrt_llm/nanobind/common/customCasters.h b/cpp/tensorrt_llm/nanobind/common/customCasters.h index a77a3bcb5a..432ce5c26b 100644 --- a/cpp/tensorrt_llm/nanobind/common/customCasters.h +++ b/cpp/tensorrt_llm/nanobind/common/customCasters.h @@ -285,5 +285,35 @@ struct type_caster>> return make_caster>::from_cpp(result, policy, cleanup); } }; + +template <> +struct type_caster +{ + NB_TYPE_CASTER(torch::ScalarType, const_name("torch.dtype")); + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept + { + std::string dtype_name = nb::cast(nb::str(src)); + if (dtype_name.substr(0, 6) == "torch.") + { + dtype_name = dtype_name.substr(6); + } + + auto const& dtype_map = c10::getStringToDtypeMap(); + auto it = dtype_map.find(dtype_name); + if (it != dtype_map.end()) + { + value = it->second; + return true; + } + + return false; + } + + static handle from_cpp(torch::ScalarType src, rv_policy policy, cleanup_list* cleanup) + { + throw std::runtime_error("from_cpp for torch::ScalarType is not implemented"); + } +}; } // namespace detail } // namespace NB_NAMESPACE diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index d3f482df89..ae4936a4df 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -240,7 +240,8 @@ void initBindings(nb::module_& m) nb::class_(executor_kv_cache, "KVCacheEvent") .def_ro("event_id", &tle::KVCacheEvent::eventId) .def_ro("data", &tle::KVCacheEvent::data) - .def_ro("window_size", &tle::KVCacheEvent::windowSize); + .def_ro("window_size", &tle::KVCacheEvent::windowSize) + .def_ro("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank); nb::class_(executor_kv_cache, "KVCacheEventManager") .def( diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 1949474a10..e56341b53e 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include #include diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index a22a62bf80..47be92e13f 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -279,7 +279,7 @@ void initBindings(nb::module_& m) .def(nb::init(), nb::arg("stream")) .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config")) - .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input")) + .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("decoder_state"), nb::arg("input")) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference) .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"), nb::arg("sampling_config"), nb::arg("streaming")) diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 189e23b8ac..59d92e6429 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -946,8 +946,8 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, std::optional gemm2; if (common::getEnvForceDeterministicMOE()) { - gemm1 = mMOERunner->getTactics()[0]; - gemm2 = mMOERunner->getTactics()[0]; + gemm1 = mMOERunner->getTactics(MoeGemmId::GEMM_1)[0]; + gemm2 = mMOERunner->getTactics(MoeGemmId::GEMM_2)[0]; } else { @@ -1278,7 +1278,7 @@ void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExper auto MixtureOfExpertsGemmProfiler::getTactics(int m, int n, int k) const -> std::vector { assert(mRunner); - return mRunner->mMOERunner->getTactics(); + return mRunner->mMOERunner->getTactics(backend.mGemmToProfile); } void MixtureOfExpertsGemmProfiler::initTmpData( diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h index cd3aaf52c2..feb1f10cdc 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h @@ -43,6 +43,7 @@ namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyParams; using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig; using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams; +using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index 91b5ebf548..9d758b427c 100755 --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -44,7 +44,7 @@ target_link_libraries( ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python - CUDA::cuda_driver + ${CUDA_DRV_LIB} ${CUDA_NVML_LIB} th_common) target_compile_definitions( diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index abccbe60a1..b5851dc1c2 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -131,6 +131,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( mSpeculativeDecodingMode = speculativeDecodingMode; + auto constexpr nvTokenIdType = TRTDataType::value; auto constexpr nvSizeType = TRTDataType::value; auto& dInput = mJointDecodingInput; @@ -179,6 +180,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs; externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype); + externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype); externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -187,8 +189,8 @@ void DecoderState::setupSpeculativeDecodingBuffers( = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType::value); externalDraftTokensInputs.useDraftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType::value); - externalDraftTokensInputs.draftTokenIds - = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); + externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType); dInput->externalDraftTokensInputs = externalDraftTokensInputs; } @@ -366,10 +368,16 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con {mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast(vocabSizePadded)}); dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape); dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape); - dInput.externalDraftTokensInputs->draftLogits->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)})); - dInput.externalDraftTokensInputs->draftTokenIds->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens})); + + auto const logitsShape = ITensor::makeShape( + {mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)}); + dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape); + dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape); + + auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens}); + dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape); diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index 494788c228..6224c0d2c9 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -83,7 +83,7 @@ add_library( reducescatterOp.cpp relativeAttentionBiasOp.cpp dsv3RouterGemmOp.cpp - renormMoeRoutingOp.cpp + customMoeRoutingOp.cpp selectiveScanOp.cpp userbuffersFinalizeOp.cpp userbuffersTensor.cpp @@ -119,9 +119,9 @@ endif() if(NOT WIN32) set_target_properties( - th_common - PROPERTIES LINK_FLAGS - "-Wl,-rpath='$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}") + th_common PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../../nvidia/nccl/lib") + set_target_properties( + th_common PROPERTIES LINK_FLAGS "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}") else() target_link_libraries(th_common PRIVATE context_attention_src) endif() diff --git a/cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp b/cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp similarity index 75% rename from cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp rename to cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp index 616cf3bb7e..814fdf87c3 100644 --- a/cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp +++ b/cpp/tensorrt_llm/thop/customMoeRoutingOp.cpp @@ -15,7 +15,7 @@ */ #include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h" +#include "tensorrt_llm/kernels/customMoeRoutingKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" namespace th = torch; @@ -25,7 +25,8 @@ namespace tk = tensorrt_llm::kernels; namespace torch_ext { -std::tuple renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +template +std::tuple custom_moe_routing_op(th::Tensor const& router_logits, int64_t topk) { auto data_type = router_logits.scalar_type(); auto input_size = router_logits.sizes(); @@ -44,20 +45,22 @@ std::tuple renorm_moe_routing_op(th::Tensor const& route { case torch::kFloat32: // Handle Float32 - tk::invokeRenormMoeRouting(reinterpret_cast(router_logits.mutable_data_ptr()), + tk::invokeRenormMoeRouting( + reinterpret_cast(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; case torch::kBFloat16: // Handle BFloat16 - tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t>( + tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t, DoSoftmaxBeforeTopK>( reinterpret_cast<__nv_bfloat16*>(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; case torch::kHalf: // Handle Half - tk::invokeRenormMoeRouting(reinterpret_cast(router_logits.mutable_data_ptr()), + tk::invokeRenormMoeRouting( + reinterpret_cast(router_logits.mutable_data_ptr()), reinterpret_cast(topk_values.mutable_data_ptr()), reinterpret_cast(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream); break; @@ -69,6 +72,15 @@ std::tuple renorm_moe_routing_op(th::Tensor const& route return {topk_indices, topk_values}; } +std::tuple renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +{ + return custom_moe_routing_op(router_logits, topk); +} + +std::tuple default_moe_routing_op(th::Tensor const& router_logits, int64_t topk) +{ + return custom_moe_routing_op(router_logits, topk); +} } // namespace torch_ext TORCH_LIBRARY_FRAGMENT(trtllm, m) @@ -82,3 +94,15 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("renorm_moe_routing_op", &torch_ext::renorm_moe_routing_op); } + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "default_moe_routing_op(Tensor router_logits, SymInt topk" + ") -> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("default_moe_routing_op", &torch_ext::default_moe_routing_op); +} diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 328cce3d01..abeba273a8 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -48,6 +48,7 @@ namespace common = tensorrt_llm::common; namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId; // Always use public header as it is just utility functions and types using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend; @@ -215,7 +216,8 @@ public: mKernelRunner->use_fused_finalize_ = mUseFusedFinalize; mProfiler = std::make_shared(); - mAllProfiles = mKernelRunner->getTactics(); + mGemm1Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_1); + mGemm2Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_2); } ~FusedMoeRunner() @@ -585,10 +587,11 @@ public: return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids); } - int64_t getTacticNum() + int64_t getTacticNum(int64_t const gemm_idx) { std::lock_guard lock(mMutex); - return mAllProfiles.size(); + TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2"); + return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size(); } // TODO Update this to be able to tell if we are profiling swiglu bias @@ -624,10 +627,14 @@ public: : group_size_; int const num_experts = static_cast(fc2_expert_weights.sizes()[0] * ep_size); + auto const gemm_to_profile + = (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2; + auto const& profiles = (gemm_idx == 1) ? mGemm1Profiles : mGemm2Profiles; + // Get specific profile configs according to the profile_id. // Fallback tactic is set to be 0 // TODO: use the best tactic id found offline for a better default inference perf - auto const& profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id]; + auto const& profile = profile_id == -1 ? profiles.front() : profiles[profile_id]; auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); @@ -638,8 +645,7 @@ public: if (do_preparation) { // Set profiled gemm idx - mProfiler->mGemmToProfile - = (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2; + mProfiler->mGemmToProfile = gemm_to_profile; // mProfiler init auto parallelism_config = kernels::MOEParallelismConfig(static_cast(tp_size), @@ -704,7 +710,8 @@ private: bool mUseFusedFinalize = true; using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - std::vector mAllProfiles; + std::vector mGemm1Profiles; + std::vector mGemm2Profiles; void freeProfileWorkspace() { @@ -730,15 +737,15 @@ private: return; } - auto best_gemm1_profile = mAllProfiles.front(); - auto best_gemm2_profile = mAllProfiles.front(); + auto best_gemm1_profile = mGemm1Profiles.front(); + auto best_gemm2_profile = mGemm2Profiles.front(); if (profile_ids.has_value()) { TORCH_CHECK(profile_ids.value().size() == 2, "Expecting 2 profile ids"); best_gemm1_profile - = profile_ids.value()[0] == -1 ? best_gemm1_profile : mAllProfiles.at(profile_ids.value()[0]); + = profile_ids.value()[0] == -1 ? best_gemm1_profile : mGemm1Profiles.at(profile_ids.value()[0]); best_gemm2_profile - = profile_ids.value()[1] == -1 ? best_gemm2_profile : mAllProfiles.at(profile_ids.value()[1]); + = profile_ids.value()[1] == -1 ? best_gemm2_profile : mGemm2Profiles.at(profile_ids.value()[1]); } mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile); } diff --git a/cpp/tests/batch_manager/cacheTransceiverTest.cpp b/cpp/tests/batch_manager/cacheTransceiverTest.cpp index 99c40f810f..4b513ae57f 100644 --- a/cpp/tests/batch_manager/cacheTransceiverTest.cpp +++ b/cpp/tests/batch_manager/cacheTransceiverTest.cpp @@ -99,7 +99,7 @@ TEST_F(RequestInfoTest, Basic) } auto state = std::make_unique(); state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"}); - state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT}); + state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT}); RequestInfo info{1, *state}; auto info2 = serializeDeserialize(info); EXPECT_EQ(info, info2); @@ -133,7 +133,7 @@ TEST_F(CacheConfigTest, EqualTo) constexpr SizeType32 tokensPerBlock{64}; constexpr SizeType32 tensorParallelism{8}; constexpr SizeType32 pipelineParallelism{2}; - constexpr SizeType32 contextParallelism{1}; + constexpr SizeType32 contextParallelism{2}; constexpr SizeType32 sizePerHead{hiddenSize / nbHeads}; constexpr CacheState::AttentionType attentionType{CacheState::AttentionType::kDEFAULT}; constexpr int kvFactor = 2; @@ -148,7 +148,7 @@ TEST_F(CacheConfigTest, EqualTo) texec::kv_cache::CacheState state0{ cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, kvFactor}; texec::kv_cache::CacheState state1{nbAttentionLayers, nbHeads, sizePerHead, tokensPerBlock, tensorParallelism, - pipelineParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism}; + pipelineParallelism, contextParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism}; EXPECT_EQ(state0, state1); } @@ -165,7 +165,7 @@ public: ON_CALL(*this, recvRequestInfo) .WillByDefault(Return(RequestInfo{0, texec::DataTransceiverState{ - texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT}, + texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT}, texec::kv_cache::CommState{std::vector{0}, 0}}})); ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1)); } @@ -217,7 +217,8 @@ TEST_F(MockTransceiverTest, MpiResponderBasic) auto sender = std::make_unique(); EXPECT_CALL(*sender, recvRequestInfo) .WillOnce(Return(RequestInfo{0, - texec::DataTransceiverState{texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT}, + texec::DataTransceiverState{ + texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT}, texec::kv_cache::CommState{std::vector{0}, 0}}})); EXPECT_CALL(*sender, sendSync).WillOnce(Return()); EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1)); @@ -318,7 +319,7 @@ protected: dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF, std::nullopt, nullptr, true); mCacheState = std::make_unique( - numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType); + numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType); if (tensorrt_llm::common::getEnvUseUCXKvCache()) { @@ -506,7 +507,7 @@ TEST_F(SymmetricalCacheTest, SimpleTest) #if ENABLE_MULTI_DEVICE using AsymmetricTestParam - = std::tuple; + = std::tuple; class AsymmetricalCacheTest : public ::testing::TestWithParam { @@ -516,8 +517,8 @@ protected: void TearDown() override {} - void setUpCommunicator(int contextTp, int contextPp, int genTp, int genPp, bool isMLA = false, - bool contextDP = false, bool generationDP = false) + void setUpCommunicator(int contextTp, int contextPp, int contextCp, int genTp, int genPp, int genCp, + bool isMLA = false, bool contextDP = false, bool generationDP = false) { #if ENABLE_MULTI_DEVICE tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE); @@ -572,11 +573,13 @@ protected: { mTpSize = contextTp; mPpSize = contextPp; + mCpSize = contextCp; } if (mIsGeneration) { mTpSize = genTp; mPpSize = genPp; + mCpSize = genCp; } mTpRank = mRankInInstance % mTpSize; @@ -585,6 +588,7 @@ protected: mGenRankSize = genRanks; mContextTpSize = contextTp; mContextPpSize = contextPp; + mContextCpSize = contextCp; EXPECT_EQ((sessionComm.getRank()), mRankInInstance); EXPECT_EQ(sessionComm.getSize(), mSizeInInstance); @@ -696,11 +700,12 @@ protected: texec::kv_cache::CacheState::AttentionType attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; - mCacheState = std::make_unique(numLayers, numHeadsPerRank, sizePerHead, - tokensPerBlock, mTpSize, mPpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize); + mCacheState + = std::make_unique(numLayers, numHeadsPerRank, sizePerHead, tokensPerBlock, + mTpSize, mPpSize, mCpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize); mContextCacheState = std::make_unique(numLayers, numHeadsPerRankForContext, - sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, dataType, attentionType, kvFactor, mContextDP, - DPrank, mContextTpSize); + sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, mContextCpSize, dataType, attentionType, + kvFactor, mContextDP, DPrank, mContextTpSize); // UVM seems to be incompatible with MPI, and it is continuing to investigate. bool constexpr useUvm = false; @@ -859,7 +864,8 @@ protected: texec::kv_cache::CacheState cacheState{mContextCacheState->getModelConfig().mNbKvHeadsPerLayer, mContextCacheState->getModelConfig().mSizePerHead, mContextCacheState->getModelConfig().mTokensPerBlock, mContextCacheState->getParallelConfig().mTensorParallelism, - mContextCacheState->getParallelConfig().mPipelineParallelism, mContextCacheState->getDataType(), + mContextCacheState->getParallelConfig().mPipelineParallelism, + mContextCacheState->getParallelConfig().mContextParallelism, mContextCacheState->getDataType(), mContextCacheState->getAttentionConfig().mAttentionType, mContextCacheState->getAttentionConfig().mKvFactor, mContextCacheState->getParallelConfig().mEnableAttentionDP, contextDpRank, mContextCacheState->getParallelConfig().mTensorParallelism}; @@ -1094,8 +1100,8 @@ protected: tensorrt_llm::mpi::MpiComm const* mComm; tensorrt_llm::mpi::MpiComm mParticipatingComm{nullptr, false}; SizeType32 mWorldSize{0}, mRank{0}, mRankInInstance{0}; - SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mContextRankSize{0}, mGenRankSize{0}, - mContextTpSize{0}, mContextPpSize{0}; + SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mCpSize{0}, mContextRankSize{0}, + mGenRankSize{0}, mContextTpSize{0}, mContextPpSize{0}, mContextCpSize{0}; LlmRequest::RequestIdType mRequestId{0}; bool mContextDP{false}; bool mGenerationDP{false}; @@ -1129,22 +1135,24 @@ TEST_P(AsymmetricalCacheTest, TestCase) AsymmetricTestParam param = GetParam(); int contextTp = std::get<0>(param); int contextPp = std::get<1>(param); - int genTp = std::get<2>(param); - int genPp = std::get<3>(param); - int numLayers = std::get<4>(param); - int numHeads = std::get<5>(param); - int sizePerHead = std::get<6>(param); - int tokensPerBlock = std::get<7>(param); - nvinfer1::DataType dataType = std::get<8>(param); + int contextCp = std::get<2>(param); + int genTp = std::get<3>(param); + int genPp = std::get<4>(param); + int genCp = std::get<5>(param); + int numLayers = std::get<6>(param); + int numHeads = std::get<7>(param); + int sizePerHead = std::get<8>(param); + int tokensPerBlock = std::get<9>(param); + nvinfer1::DataType dataType = std::get<10>(param); - int kvFactor = std::get<9>(param); - bool isMLA = std::get<10>(param); - bool contextDP = std::get<11>(param); - bool generationDP = std::get<12>(param); + int kvFactor = std::get<11>(param); + bool isMLA = std::get<12>(param); + bool contextDP = std::get<13>(param); + bool generationDP = std::get<14>(param); - bool isWindow = std::get<13>(param); + bool isWindow = std::get<15>(param); - setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP); + setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP); if (mIsContext || mIsGeneration) { @@ -1221,21 +1229,23 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase) AsymmetricTestParam param = GetParam(); int contextTp = std::get<0>(param); int contextPp = std::get<1>(param); - int genTp = std::get<2>(param); - int genPp = std::get<3>(param); - int numLayers = std::get<4>(param); - int numHeads = std::get<5>(param); - int sizePerHead = std::get<6>(param); - int tokensPerBlock = std::get<7>(param); - nvinfer1::DataType dataType = std::get<8>(param); + int contextCp = std::get<2>(param); + int genTp = std::get<3>(param); + int genPp = std::get<4>(param); + int genCp = std::get<5>(param); + int numLayers = std::get<6>(param); + int numHeads = std::get<7>(param); + int sizePerHead = std::get<8>(param); + int tokensPerBlock = std::get<9>(param); + nvinfer1::DataType dataType = std::get<10>(param); - int kvFactor = std::get<9>(param); - bool isMLA = std::get<10>(param); - bool contextDP = std::get<11>(param); - bool generationDP = std::get<12>(param); - bool isWindow = std::get<13>(param); + int kvFactor = std::get<11>(param); + bool isMLA = std::get<12>(param); + bool contextDP = std::get<13>(param); + bool generationDP = std::get<14>(param); + bool isWindow = std::get<15>(param); - setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP); + setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP); if (mIsContext || mIsGeneration) { @@ -1324,95 +1334,95 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase) } INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0, AsymmetricalCacheTest, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true, false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest, - testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(5), - testing::Values(4), testing::Values(4), testing::Values(8), + testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), + testing::Values(1), testing::Values(5), testing::Values(4), testing::Values(4), testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest, - testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(8), - testing::Values(4), testing::Values(4), testing::Values(8), + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(8), testing::Values(4), testing::Values(4), testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2, AsymmetricalCacheTest, - testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1, 4), - testing::Values(16), testing::Values(16), testing::Values(4), testing::Values(8), - testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false), - testing::Values(false), testing::Values(false))); + testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1), + testing::Values(1, 4), testing::Values(1), testing::Values(16), testing::Values(16), testing::Values(4), + testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), + testing::Values(false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0ForMLA, AsymmetricalCacheTest, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLA, AsymmetricalCacheTest, - testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(4), - testing::Values(1), testing::Values(4), testing::Values(8), + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(true), testing::Values(true), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA2, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(true), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(true), testing::Values(true), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA1, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(true), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), - testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), + testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(4), testing::Values(1), - testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(2), testing::Values(2), - testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(2), + testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(4), testing::Values(1), testing::Values(4, 2), testing::Values(1), - testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4, 2), + testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP, - testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1, 2), testing::Values(2), - testing::Values(4), testing::Values(1, 2), testing::Values(4), testing::Values(16), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2), + testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1, 2), testing::Values(4), + testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false))); #endif @@ -1430,8 +1440,10 @@ TEST(targetTest, CacheStateNODP) int contextPP = 2; int contextTP = 4; + int contextCP = 1; int genPP = 2; int genTP = 2; + int genCP = 1; bool const contextEnableDP = false; bool const genEnableDP = false; @@ -1441,10 +1453,10 @@ TEST(targetTest, CacheStateNODP) auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, - tokensPerBlock, contextTP, contextPP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0}; + tokensPerBlock, contextTP, contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0}; auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, - tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, 0, 0}; + tokensPerBlock, genTP, genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, 0, 0}; auto const contextTragetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank); @@ -1504,8 +1516,10 @@ TEST(targetTest, CacheStateContextDP) int contextPP = 1; int contextTP = 4; + int contextCP = 1; int genPP = 1; int genTP = 2; + int genCP = 1; bool contextEnableDP = true; bool genEnableDP = true; @@ -1519,10 +1533,11 @@ TEST(targetTest, CacheStateContextDP) auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP, - contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; + contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; - auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, - tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; + auto const genCache + = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP, + genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; auto const contextTragetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank); @@ -1625,10 +1640,11 @@ TEST(targetTest, CacheStateContextDP) auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP, - contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; + contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; - auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, - tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; + auto const genCache + = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP, + genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; auto const contextTragetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank); diff --git a/cpp/tests/unit_tests/executor/agentCommTest.cpp b/cpp/tests/unit_tests/executor/agentCommTest.cpp index 9c23f33f50..d9e6aaa138 100644 --- a/cpp/tests/unit_tests/executor/agentCommTest.cpp +++ b/cpp/tests/unit_tests/executor/agentCommTest.cpp @@ -90,7 +90,7 @@ protected: size_t maxNumTokens = 1024; mTransBufferManager = std::make_unique(mCacheManager.get(), maxNumTokens); - mCacheState = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType); + mCacheState = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType); } void TearDown() override diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 27fff8df7d..1dad1fa2bb 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -726,7 +726,7 @@ TEST(SerializeUtilsTest, ContextPhaseParams) { auto state = std::make_unique(); state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"}); - state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT}); + state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT}); auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 0, state.release(), VecTokens{10, 20}); auto stats2 = serializeDeserialize(stats); EXPECT_EQ(stats, stats2); diff --git a/cpp/tests/unit_tests/executor/transferAgentTest.cpp b/cpp/tests/unit_tests/executor/transferAgentTest.cpp index c73d9a2140..4745e8e40b 100644 --- a/cpp/tests/unit_tests/executor/transferAgentTest.cpp +++ b/cpp/tests/unit_tests/executor/transferAgentTest.cpp @@ -255,7 +255,8 @@ TEST_F(TransferAgentTest, SyncMessage) checked = nixlAgent0->checkRemoteDescs(agent1, regMem3.getDescs()); } while (!checked); auto syncMessage = std::string("agent_sync_message"); - TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1, syncMessage}; + nixlAgent0->notifySyncMessage(agent1, syncMessage); + TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1}; auto status = nixlAgent0->submitTransferRequests(writeReq); auto notif = nixlAgent1->getNotifiedSyncMessages(); @@ -302,7 +303,8 @@ TEST_F(TransferAgentTest, SyncMessage) } while (!checked2); std::string syncMessage4 = "four_agent_sync_message"; - TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0, syncMessage4}; + nixlAgent1->notifySyncMessage(agent0, syncMessage4); + TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0}; auto status1 = nixlAgent1->submitTransferRequests(writeReq1); auto notif4 = nixlAgent0->getNotifiedSyncMessages(); for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif4.size() == 0; i++) diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index 6f2ce0f93e..11ae4273dc 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -370,8 +370,8 @@ protected: float mSparseMixerEpsilon = 0.2f; - // Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths - bool mUseDeterministicHopperReduce = true; + // Default this to false. This only matters for K>2, and so by doing this we will test the fused and unfused paths + bool mUseFusedFinalize = false; // Disable this for long running tests to speed up runtime bool mIsLongTest = false; @@ -456,7 +456,7 @@ protected: { managed_buffers.clear(); - mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce; + mMoERunner.use_fused_finalize_ = k < 3 || mUseFusedFinalize; mHiddenSize = hidden_size; mInterSize = hidden_size * mInterSizeFraction; @@ -1087,9 +1087,9 @@ protected: return std::tuple{(void*) weight_1, (void*) weight_2, bias_1, bias2_ptr, scale_1, scale_2, scale_3}; } - auto getFilteredConfigs(int sm) + auto getFilteredConfigs(int sm, MoeGemmId gemm_id) { - auto tactics = mMoERunner.getTactics(); + auto tactics = mMoERunner.getTactics(gemm_id); if (sm == 89 || sm >= 120) { // Filter some unsupported configs for L40S @@ -1120,17 +1120,27 @@ protected: auto selectTacticsForArch(int sm) { bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT; - auto tactics = getFilteredConfigs(sm); - auto it = std::find_if(tactics.begin(), tactics.end(), + auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize) + ? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE + : tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE; + auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1); + auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2); + auto it1 = std::find_if(tactics1.begin(), tactics1.end(), [is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; }); - if (it == tactics.end()) + auto it2 = std::find_if(tactics2.begin(), tactics2.end(), + [is_tma_warp_specialized, epilogue_fusion_type](auto& c) { + return c.is_tma_warp_specialized == is_tma_warp_specialized + && c.epilogue_fusion_type == epilogue_fusion_type; + }); + if (it1 == tactics1.end() || it2 == tactics2.end()) { // Fall back to any tactic std::cout << "WARNING: Could not find config for sm version " << sm << std::endl; - return std::pair{tactics[0], tactics[0]}; + it1 = (it1 == tactics1.end()) ? tactics1.begin() : it1; + it2 = (it2 == tactics2.end()) ? tactics2.begin() : it2; } - return std::pair(*it, *it); + return std::pair(*it1, *it2); } using ConfigsToTestVec = std::vectorget(); auto tactic1 = mInternalSelectedConfig1; auto tactic2 = mInternalSelectedConfig2; - if (!tactic1) + if (!tactic1 || !tactic2) { int sm = getSMVersion(); std::tie(tactic1, tactic2) = selectTacticsForArch(sm); @@ -1629,8 +1639,9 @@ void MixtureOfExpertsTest::BasicPermuteTest( auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k); runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k); - bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; + bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1749,7 +1760,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias) TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic) { - this->mUseDeterministicHopperReduce = false; + this->mUseFusedFinalize = true; // Just test case 3, cases 1&2 always use the fused paths this->BasicPermuteTest(3); } @@ -1896,8 +1907,10 @@ void MixtureOfExpertsTest::ParallelismTest( // Only need to init the inputs on the first iteration runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k, MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -1912,8 +1925,10 @@ void MixtureOfExpertsTest::ParallelismTest( else { runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall); + bool is_finalize_fusion = gemm2.epilogue_fusion_type + == tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE; bool should_be_deterministic - = mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; + = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120; if (should_be_deterministic && !mIsLongTest) { auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -2077,6 +2092,7 @@ PARALLEL_TEST_SUITE(MixedParallel) TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) { this->mIsLongTest = true; + this->mUseFusedFinalize = true; // True for all cases because we sweep both auto genConfigName = [](auto conf) -> std::string { using namespace tensorrt_llm::cutlass_extensions; @@ -2103,12 +2119,13 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::SwigluBias}; if (this->NVFP4) activation_pool = {ActivationType::Relu}; - auto configs = this->getFilteredConfigs(getSMVersion()); + auto configs1 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_1); + auto configs2 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_2); for (auto const activation_type : activation_pool) { - for (auto conf1 : configs) + for (auto conf1 : configs1) { - for (auto conf2 : configs) + for (auto conf2 : configs2) { auto name1 = genConfigName(conf1); auto name2 = genConfigName(conf2); @@ -2120,7 +2137,6 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep) this->mActType = activation_type; for (auto k : {2, 3}) { - this->mOverrideSelectedConfig1 = conf1; this->mOverrideSelectedConfig2 = conf2; this->BasicPermuteTest(k, this->MINIMUM_ALIGNMENT); diff --git a/docker/common/install_nixl.sh b/docker/common/install_nixl.sh index 18ee554f69..cecd61a7af 100644 --- a/docker/common/install_nixl.sh +++ b/docker/common/install_nixl.sh @@ -4,8 +4,9 @@ set -ex GITHUB_URL="https://github.com" UCX_INSTALL_PATH="/usr/local/ucx/" CUDA_PATH="/usr/local/cuda" -NIXL_VERSION="0.3.1" +NIXL_VERSION="0.5.0" NIXL_REPO="https://github.com/ai-dynamo/nixl.git" +OLD_LD_LIBRARY_PATH=$LD_LIBRARY_PATH ARCH_NAME="x86_64-linux-gnu" GDS_PATH="$CUDA_PATH/targets/x86_64-linux" @@ -18,25 +19,26 @@ pip3 install --no-cache-dir meson ninja pybind11 git clone --depth 1 -b ${NIXL_VERSION} ${NIXL_REPO} cd nixl -cuda_path=$(find / -name "libcuda.so.1" 2>/dev/null | head -n1) -if [[ -z "$cuda_path" ]]; then - echo "libcuda.so.1 not found " +CUDA_SO_PATH=$(find "/usr/local" -name "libcuda.so.1" 2>/dev/null | head -n1) + +if [[ -z "$CUDA_SO_PATH" ]]; then + echo "libcuda.so.1 not found" exit 1 fi -ln -sf $cuda_path $CUDA_PATH/lib64/libcuda.so.1 +CUDA_SO_PATH=$(dirname $CUDA_SO_PATH) +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_SO_PATH meson setup builddir \ -Ducx_path=$UCX_INSTALL_PATH \ -Dcudapath_lib="$CUDA_PATH/lib64" \ -Dcudapath_inc="$CUDA_PATH/include" \ -Dgds_path="$GDS_PATH" \ - -Dinstall_headers=true \ - -Dstatic_plugins=UCX + -Dinstall_headers=true cd builddir && ninja install cd ../.. rm -rf nixl* # Remove NIXL source tree to save space -rm $CUDA_PATH/lib64/libcuda.so.1 +export LD_LIBRARY_PATH=$OLD_LD_LIBRARY_PATH echo "export LD_LIBRARY_PATH=/opt/nvidia/nvda_nixl/lib/${ARCH_NAME}:/opt/nvidia/nvda_nixl/lib64:\$LD_LIBRARY_PATH" >> "${ENV}" diff --git a/docker/common/install_ucx.sh b/docker/common/install_ucx.sh index 613ac1c773..ba35e82ce6 100644 --- a/docker/common/install_ucx.sh +++ b/docker/common/install_ucx.sh @@ -2,29 +2,28 @@ set -ex GITHUB_URL="https://github.com" -UCX_VERSION="v1.19.0" +UCX_VERSION="v1.19.x" UCX_INSTALL_PATH="/usr/local/ucx/" CUDA_PATH="/usr/local/cuda" UCX_REPO="https://github.com/openucx/ucx.git" -if [ ! -d ${UCX_INSTALL_PATH} ]; then - git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO} - cd ucx - ./autogen.sh - ./contrib/configure-release \ - --prefix=${UCX_INSTALL_PATH} \ - --enable-shared \ - --disable-static \ - --disable-doxygen-doc \ - --enable-optimizations \ - --enable-cma \ - --enable-devel-headers \ - --with-cuda=${CUDA_PATH} \ - --with-verbs \ - --with-dm \ - --enable-mt - make install -j$(nproc) - cd .. - rm -rf ucx # Remove UCX source to save space - echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}" -fi +rm -rf ${UCX_INSTALL_PATH} +git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO} +cd ucx +./autogen.sh +./contrib/configure-release \ + --prefix=${UCX_INSTALL_PATH} \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=${CUDA_PATH} \ + --with-verbs \ + --with-dm \ + --enable-mt +make install -j$(nproc) +cd .. +rm -rf ucx # Remove UCX source to save space +echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}" diff --git a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md index 8f5c1dfec0..87432173b4 100644 --- a/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md +++ b/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md @@ -19,11 +19,11 @@ We have a forthcoming guide for achieving great performance on H100; however, th In this section, we introduce several ways to install TensorRT-LLM. -### NGC Docker Image of dev branch +### NGC Docker Image -Day-0 support for gpt-oss is provided via the NGC container image `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev`. This image was built on top of the pre-day-0 **dev branch**. This container is multi-platform and will run on both x64 and arm64 architectures. +Visit the [NGC TensorRT-LLM Release page](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) to find the most up-to-date NGC container image to use. You can also check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status of the latest releases. -Run the following docker command to start the TensorRT-LLM container in interactive mode: +Run the following Docker command to start the TensorRT-LLM container in interactive mode (change the image tag to match latest release): ```bash docker run --rm --ipc=host -it \ @@ -33,7 +33,7 @@ docker run --rm --ipc=host -it \ -p 8000:8000 \ -e TRTLLM_ENABLE_PDL=1 \ -v ~/.cache:/root/.cache:rw \ - nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev \ + nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc0 \ /bin/bash ``` @@ -53,9 +53,9 @@ Additionally, the container mounts your user `.cache` directory to save the down Support for gpt-oss has been [merged](https://github.com/NVIDIA/TensorRT-LLM/pull/6645) into the **main branch** of TensorRT-LLM. As we continue to optimize gpt-oss performance, you can build TensorRT-LLM from source to get the latest features and support. Please refer to the [doc](https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html) if you want to build from source yourself. -### Regular Release of TensorRT-LLM +### TensorRT-LLM Python Wheel Install -Since gpt-oss has been supported on the main branch, you can get TensorRT-LLM out of the box through its regular release in the future. Please check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status. The release is provided as [NGC Container Image](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) or [pip Python wheel](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html). +Regular releases of TensorRT-LLM are also provided as [Python wheels](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on the pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html). ## Performance Benchmarking and Model Serving @@ -210,7 +210,10 @@ We can use `trtllm-serve` to serve the model by translating the benchmark comman ```bash trtllm-serve \ - gpt-oss-120b \ # Or ${local_model_path} +Note: You can also point to a local path containing the model weights instead of the HF repo (e.g., `${local_model_path}`). + +trtllm-serve \ + openai/gpt-oss-120b \ --host 0.0.0.0 \ --port 8000 \ --backend pytorch \ @@ -228,7 +231,8 @@ For max-throughput configuration, run: ```bash trtllm-serve \ - gpt-oss-120b \ # Or ${local_model_path} +trtllm-serve \ + openai/gpt-oss-120b \ --host 0.0.0.0 \ --port 8000 \ --backend pytorch \ @@ -262,7 +266,7 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d ' "messages": [ { "role": "user", - "content": "What is NVIDIA's advantage for inference?" + "content": "What is NVIDIAs advantage for inference?" } ], "max_tokens": 1024, @@ -348,12 +352,7 @@ others according to your needs. ## (H200/H100 Only) Using OpenAI Triton Kernels for MoE -OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper-based GPUs like NVIDIA's H200 for optimal performance. `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still ongoing. Please enable `TRITON` backend with the steps below if you are running on Hopper GPUs. - -### Installing OpenAI Triton - -The `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` has prepared Triton already (`echo $TRITON_ROOT` could reveal the path). In other situations, you will need to build and install a specific version of Triton. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe). - +OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper-based GPUs like NVIDIA's H200 for optimal performance. `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still ongoing. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe) to install and enable the `TRITON` MoE kernels on Hopper GPUs. ### Selecting Triton as the MoE backend diff --git a/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md new file mode 100644 index 0000000000..b201deb8f4 --- /dev/null +++ b/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md @@ -0,0 +1,328 @@ +# Quick Start Recipe for GPT-OSS on TensorRT-LLM - Blackwell Hardware + +## Introduction + +This deployment guide provides step-by-step instructions for running the GPT-OSS model using TensorRT-LLM, optimized for NVIDIA GPUs. It covers the complete setup required; from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output. + +The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA’s accelerated stack—starting with the PyTorch container from NGC, then installing TensorRT-LLM for model serving. + +## Prerequisites + +* GPU: NVIDIA Blackwell Architecture +* OS: Linux +* Drivers: CUDA Driver 575 or Later +* Docker with NVIDIA Container Toolkit installed +* Python3 and python3-pip (Optional, for accuracy evaluation only) + +## Models + +* MXFP4 model: [GPT-OSS-120B](https://huggingface.co/openai/gpt-oss-120b) + + +## MoE Backend Support Matrix + +There are multiple MOE backends inside TRT-LLM. Here are the support matrix of the MOE backends. + +| Device | Activation Type | MoE Weights Type | MoE Backend | Use Case | +|------------|------------------|------------------|-------------|----------------| +| B200/GB200 | MXFP8 | MXFP4 | TRTLLM | Low Latency | +| B200/GB200 | MXFP8 | MXFP4 | CUTLASS | Max Throughput | + +The default moe backend is `CUTLASS`, so for the combination which is not supported by `CUTLASS`, one must set the `moe_config.backend` explicitly to run the model. + +## Deployment Steps + +### Run Docker Container + +Run the docker container using the TensorRT-LLM NVIDIA NGC image. + +```shell +docker run --rm -it \ +--ipc=host \ +--gpus all \ +-p 8000:8000 \ +-v ~/.cache:/root/.cache:rw \ +--name tensorrt_llm \ +nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \ +/bin/bash +``` + +Note: + +* The command mounts your user `.cache` directory to save the downloaded model checkpoints which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. If the `~/.cache` directory doesn’t exist please create it using `$ mkdir ~/.cache`. +* You can mount additional directories and paths using the `-v :` flag if needed, such as mounting the downloaded weight paths. +* The command also maps port `8000` from the container to your host so you can access the LLM API endpoint from your host +* See the for all the available containers. The containers published in the main branch weekly have `rcN` suffix, while the monthly release with QA tests has no `rcN` suffix. Use the `rc` release to get the latest model and feature support. + +If you want to use latest main branch, you can choose to build from source to install TensorRT-LLM, the steps refer to . + +### Creating the TRT-LLM Server config + +We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings. + +For low-latency with `TRTLLM` MOE backend: + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: false +cuda_graph_config: + enable_padding: true + max_batch_size: 128 +moe_config: + backend: TRTLLM +EOF +``` + +For max-throughput with `CUTLASS` MOE backend: + +```shell +EXTRA_LLM_API_FILE=/tmp/config.yml + +cat << EOF > ${EXTRA_LLM_API_FILE} +enable_attention_dp: true +cuda_graph_config: + enable_padding: true + max_batch_size: 128 +moe_config: + backend: CUTLASS +EOF +``` + +### Launch the TRT-LLM Server + +Below is an example command to launch the TRT-LLM server with the GPT-OSS model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section. + +```shell +trtllm-serve openai/gpt-oss-120b \ + --host 0.0.0.0 \ + --port 8000 \ + --backend pytorch \ + --max_batch_size 128 \ + --max_num_tokens 16384 \ + --max_seq_len 2048 \ + --kv_cache_free_gpu_memory_fraction 0.9 \ + --tp_size 8 \ + --ep_size 8 \ + --trust_remote_code \ + --extra_llm_api_options ${EXTRA_LLM_API_FILE} +``` + +After the server is set up, the client can now send prompt requests to the server and receive results. + +### Configs and Parameters + +These options are used directly on the command line when you start the `trtllm-serve` process. + +#### `--tp_size` + +* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance. + +#### `--ep_size` + +* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models. + +#### `--kv_cache_free_gpu_memory_fraction` + +* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors. +* **Recommendation:** If you experience OOM errors, try reducing this value to `0.7` or lower. + +#### `--backend pytorch` + +* **Description:** Tells TensorRT-LLM to use the **pytorch** backend. + +#### `--max_batch_size` + +* **Description:** The maximum number of user requests that can be grouped into a single batch for processing. + +#### `--max_num_tokens` + +* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch. + +#### `--max_seq_len` + +* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens. + +#### `--trust_remote_code` + +* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API. + + +#### Extra LLM API Options (YAML Configuration) + +These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument. + +#### `cuda_graph_config` + +* **Description**: A section for configuring CUDA graphs to optimize performance. + +* **Options**: + + * `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance. + + **Default**: `false` + + * `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created. + + **Default**: `0` + + **Recommendation**: Set this to the same value as the `--max_batch_size` command-line option. + +#### `moe_config` + +* **Description**: Configuration for Mixture-of-Experts (MoE) models. + +* **Options**: + + * `backend`: The backend to use for MoE operations. + **Default**: `CUTLASS` + +See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`. + +## Testing API Endpoint + +### Basic Test + +Start a new terminal on the host to test the TensorRT-LLM server you just launched. + +You can query the health/readiness of the server using: + +```shell +curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health" +``` + +When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation. + +After the TRT-LLM server is set up and shows Application startup complete, you can send requests to the server. + +```shell +curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "openai/gpt-oss-120b", + "messages": [ + { + "role": "user", + "content": "Where is New York?" + } + ], + "max_tokens": 1024, + "top_p": 1.0 +}' -w "\n" +``` + +Here is an example response, showing that the TRT-LLM server reasons and answers the questions. + +TODO: Use Chat Compeletions API / Responses API as the example after the PR is merged. + +```json +{"id":"chatcmpl-c5bf51b5cab94e10ba5da5266d12ee59","object":"chat.completion","created":1755815898,"model":"openai/gpt-oss-120b","choices":[{"index":0,"message":{"role":"assistant","content":"analysisThe user asks: \"Where is New York?\" Likely they want location info. Provide answer: New York State in northeastern US, New York City on the east coast, coordinates, etc. Provide context.assistantfinal**New York** can refer to two related places in the United States:\n\n| What it is | Where it is | Approx. coordinates | How to picture it |\n|------------|------------|--------------------|-------------------|\n| **New York State** | The northeastern corner of the United States, bordered by **Vermont, Massachusetts, Connecticut, New Jersey, Pennsylvania, and the Canadian provinces of Ontario and Quebec**. | 42.7° N, 75.5° W (roughly the state’s geographic centre) | A roughly rectangular state that stretches from the Atlantic Ocean in the southeast to the Adirondack Mountains and the Great Lakes region in the north. |\n| **New York City (NYC)** | The largest city in the state, located on the **southern tip of the state** where the **Hudson River meets the Atlantic Ocean**. It occupies five boroughs: Manhattan, Brooklyn, Queens, The Bronx, and Staten Island. | 40.7128° N, 74.0060° W | A dense, world‑famous metropolis that sits on a series of islands (Manhattan, Staten Island, parts of the Bronx) and the mainland (Brooklyn and Queens). |\n\n### Quick geographic context\n- **On a map of the United States:** New York State is in the **Northeast** region, just east of the Great Lakes and north of Pennsylvania. \n- **From Washington, D.C.:** Travel roughly **225 mi (360 km) northeast**. \n- **From Boston, MA:** Travel about **215 mi (350 km) southwest**. \n- **From Toronto, Canada:** Travel about **500 mi (800 km) southeast**.\n\n### Travel tips\n- **By air:** Major airports include **John F. Kennedy International (JFK)**, **LaGuardia (LGA)**, and **Newark Liberty International (EWR)** (the latter is actually in New Jersey but serves the NYC metro area). \n- **By train:** Amtrak’s **Northeast Corridor** runs from **Boston → New York City → Washington, D.C.** \n- **By car:** Interstates **I‑87** (north‑south) and **I‑90** (east‑west) are the primary highways crossing the state.\n\n### Fun fact\n- The name “**New York**” was given by the English in 1664, honoring the Duke of York (later King James II). The city’s original Dutch name was **“New Amsterdam.”**\n\nIf you need more specific directions (e.g., how to get to a particular neighborhood, landmark, or the state capital **Albany**), just let me know!","reasoning_content":null,"tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null,"mm_embedding_handle":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":72,"total_tokens":705,"completion_tokens":633},"prompt_token_ids":null} +``` + +### Troubleshooting Tips + +* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`. +* Ensure your model checkpoints are compatible with the expected format. +* For performance issues, check GPU utilization with nvidia-smi while the server is running. +* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed. +* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application. + +### Running Evaluations to Verify Accuracy (Optional) + +We use OpenAI's official evaluation tool to test the model's accuracy. For more information see [https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals](gpt-oss-eval). + +TODO(@Binghan Chen): Add instructions for running gpt-oss-eval. + +## Benchmarking Performance + +To benchmark the performance of your TensorRT-LLM server you can leverage the built-in `benchmark_serving.py` script. To do this first creating a wrapper `bench.sh` script. + +```shell +cat <<'EOF' > bench.sh +#!/usr/bin/env bash +set -euo pipefail + +concurrency_list="32 64 128 256 512 1024 2048 4096" +multi_round=5 +isl=1024 +osl=1024 +result_dir=/tmp/gpt_oss_output + +for concurrency in ${concurrency_list}; do + num_prompts=$((concurrency * multi_round)) + python -m tensorrt_llm.serve.scripts.benchmark_serving \ + --model openai/gpt-oss-120b \ + --backend openai \ + --dataset-name "random" \ + --random-input-len ${isl} \ + --random-output-len ${osl} \ + --random-prefix-len 0 \ + --random-ids \ + --num-prompts ${num_prompts} \ + --max-concurrency ${concurrency} \ + --ignore-eos \ + --tokenize-on-client \ + --percentile-metrics "ttft,tpot,itl,e2el" +done +EOF +chmod +x bench.sh +``` + +If you want to save the results to a file add the following options. + +```shell +--save-result \ +--result-dir "${result_dir}" \ +--result-filename "concurrency_${concurrency}.json" +``` + +For more benchmarking options see . + +Run `bench.sh` to begin a serving benchmark. This will take a long time if you run all the concurrencies mentioned in the above `bench.sh` script. + +```shell +./bench.sh +``` + +Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations. + +``` +============ Serving Benchmark Result ============ +Successful requests: 16 +Benchmark duration (s): 17.66 +Total input tokens: 16384 +Total generated tokens: 16384 +Request throughput (req/s): [result] +Output token throughput (tok/s): [result] +Total Token throughput (tok/s): [result] +User throughput (tok/s): [result] +---------------Time to First Token---------------- +Mean TTFT (ms): [result] +Median TTFT (ms): [result] +P99 TTFT (ms): [result] +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): [result] +Median TPOT (ms): [result] +P99 TPOT (ms): [result] +---------------Inter-token Latency---------------- +Mean ITL (ms): [result] +Median ITL (ms): [result] +P99 ITL (ms): [result] +----------------End-to-end Latency---------------- +Mean E2EL (ms): [result] +Median E2EL (ms): [result] +P99 E2EL (ms): [result] +================================================== +``` + +### Key Metrics + +* Median Time to First Token (TTFT) + * The typical time elapsed from when a request is sent until the first output token is generated. +* Median Time Per Output Token (TPOT) + * The typical time required to generate each token *after* the first one. +* Median Inter-Token Latency (ITL) + * The typical time delay between the completion of one token and the completion of the next. +* Median End-to-End Latency (E2EL) + * The typical total time from when a request is submitted until the final token of the response is received. +* Total Token Throughput + * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens. diff --git a/docs/source/index.rst b/docs/source/index.rst index b0964ca287..df00acb02b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,6 +38,7 @@ Welcome to TensorRT-LLM's Documentation! deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md + deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md .. toctree:: diff --git a/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md new file mode 100644 index 0000000000..5a73d047ea --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md @@ -0,0 +1,77 @@ +# Serving with trtllm-serve + +AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate with a simple request. + +## Quick start + +Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`: + +```bash +trtllm-serve \ + meta-llama/Llama-3.1-8B-Instruct \ + --backend _autodeploy +``` + +- `model`: HF name or local path +- `--backend _autodeploy`: uses AutoDeploy runtime + +Once the server is ready, test with an OpenAI-compatible request: + +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "meta-llama/Llama-3.1-8B-Instruct", + "messages":[{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Where is New York? Tell me in a single sentence."}], + "max_tokens": 32 + }' +``` + +## Configuration via YAML + +Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings. + +```bash +trtllm-serve \ + meta-llama/Llama-3.1-8B \ + --backend _autodeploy \ + --extra_llm_api_options autodeploy_config.yaml +``` + +Example `autodeploy_config.yaml`: + +```yaml +# Compilation backend for AutoDeploy +compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt + +# Runtime engine +runtime: trtllm # options: trtllm, demollm + +# Model loading +skip_loading_weights: false # set true for architecture-only perf runs + +# KV cache memory +free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache + +# CUDA graph optimization +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64] + +# Attention backend +attn_backend: flashinfer # recommended for best performance +``` + +## Limitations and tips + +- KV cache block reuse is disabled automatically for AutoDeploy backend +- AutoDeploy backend doesn't yet support disaggregated serving. WIP +- For best performance: + - Prefer `compile_backend: torch-opt` + - Use `attn_backend: flashinfer` + - Set realistic `cuda_graph_batch_sizes` that match expected traffic + - Tune `free_mem_ratio` to 0.8–0.9 + +## See also + +- [AutoDeploy overview](../auto-deploy.md) +- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md) diff --git a/docs/source/torch/auto_deploy/auto-deploy.md b/docs/source/torch/auto_deploy/auto-deploy.md index fc00c0ccc3..185e1f321a 100644 --- a/docs/source/torch/auto_deploy/auto-deploy.md +++ b/docs/source/torch/auto_deploy/auto-deploy.md @@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi - [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md) - [Expert Configurations](./advanced/expert_configurations.md) - [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md) +- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md) ## Roadmap diff --git a/examples/constraints.txt b/examples/constraints.txt index 4ce23b0de7..8b0d1a0093 100644 --- a/examples/constraints.txt +++ b/examples/constraints.txt @@ -1,3 +1,3 @@ -tensorrt_llm==1.1.0rc1 +tensorrt_llm==1.1.0rc2 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/llm-api/star_attention.py b/examples/llm-api/star_attention.py index 367f7cc843..d87895e71a 100644 --- a/examples/llm-api/star_attention.py +++ b/examples/llm-api/star_attention.py @@ -7,8 +7,8 @@ from difflib import SequenceMatcher import torch from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.mapping import CpType -from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig def dump_jsonl(data, fname): @@ -54,11 +54,8 @@ def similarity_score(a, b): return SequenceMatcher(None, a, b).ratio() -# Generate the outputs using either TRT or PyTorch (based on the use_pytorch argument). It’s the same function for both workflows. def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False): - quant_config = QuantConfig(quant_algo=QuantAlgo.FP8, - kv_cache_quant_algo=QuantAlgo.FP8 if fp8_kv_cache - else None) if fp8 else QuantConfig() + kv_cache_config = KvCacheConfig(dtype="fp8" if fp8_kv_cache else "auto") cp_config = { "cp_type": CpType.STAR, "cp_anchor_size": args.sa_anchor_size, @@ -70,7 +67,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False): max_input_len=args.max_input_len, max_seq_len=args.max_seq_len, max_num_tokens=args.max_num_tokens, - quant_config=quant_config, + kv_cache_config=kv_cache_config, tensor_parallel_size=1, context_parallel_size=args.num_procs, cp_config=cp_config, diff --git a/jenkins/Build.groovy b/jenkins/Build.groovy index b7ff896665..ecacc33f3a 100644 --- a/jenkins/Build.groovy +++ b/jenkins/Build.groovy @@ -57,10 +57,10 @@ def CONFIG_LINUX_AARCH64_CU12 = "linux_aarch64_CU12" def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM" @Field -def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind" +def CONFIG_LINUX_X86_64_PYBIND = "linux_x86_64_Pybind" @Field -def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind" +def CONFIG_LINUX_AARCH64_PYBIND = "linux_aarch64_Pybind" @Field def BUILD_CONFIGS = [ @@ -76,9 +76,9 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM-CU12.tar.gz", (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", ], - (CONFIG_LINUX_X86_64_NANOBIND) : [ - (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks", - (TARNAME) : "nanobind-TensorRT-LLM.tar.gz", + (CONFIG_LINUX_X86_64_PYBIND) : [ + (WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks", + (TARNAME) : "pybind-TensorRT-LLM.tar.gz", (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", ], (CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [ @@ -101,9 +101,9 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz", (WHEEL_ARCHS): "90-real;100-real;120-real", ], - (CONFIG_LINUX_AARCH64_NANOBIND): [ - (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON", - (TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz", + (CONFIG_LINUX_AARCH64_PYBIND): [ + (WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON", + (TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz", (WHEEL_ARCHS): "90-real;100-real;120-real", ], (CONFIG_LINUX_AARCH64_LLVM) : [ @@ -568,8 +568,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars) pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_CU12 : CONFIG_LINUX_X86_64_VANILLA_CU12), "Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild( pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM), - "Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild( - pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND), + "Build TRT-LLM Pybind": [LLM_DOCKER_IMAGE] + prepareLLMBuild( + pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_PYBIND : CONFIG_LINUX_X86_64_PYBIND), ] if (cpu_arch == X86_64_TRIPLE) { diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 5e52ebe7c5..b7cdf167c5 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -74,7 +74,7 @@ def LINUX_AARCH64_CONFIG = "linux_aarch64" def LINUX_AARCH64_CONFIG_CU12 = "linux_aarch64_CU12" @Field -def NANOBIND_CONFIG = "Nanobind" +def PYBIND_CONFIG = "Pybind" @Field def BUILD_CONFIGS = [ @@ -85,7 +85,7 @@ def BUILD_CONFIGS = [ (LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"], (LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"], (LINUX_AARCH64_CONFIG_CU12) : [(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz"], - (NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"], + (PYBIND_CONFIG) : [(TARNAME) : "pybind-TensorRT-LLM.tar.gz"], ] // TODO: Move common variables to an unified location @@ -657,8 +657,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod def driverVersion = Constants.DEFAULT_NVIDIA_DRIVER_VERSION def cpuCount = "${TESTER_CORES}" - // Multi-GPU only supports DGX-H100 and DGX-H200 due to the hardware stability. - if ((type.contains("dgx-h100") || type.contains("dgx-h200")) && hasMultipleGPUs) + if (hasMultipleGPUs) { // Not a hard requirement, but based on empirical values. memorySize = "${gpuCount * 150}" + "Gi" @@ -672,7 +671,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod targetCould = "kubernetes" // The following GPU types doesn't support dynamic driver flashing. - if (type.contains("dgx-h100") || type.contains("dgx-h200") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) { + if (type.contains("dgx-h100") || type.contains("dgx-h200") || type.contains("rtx-pro-6000") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) { selectors = """ kubernetes.io/arch: ${arch} kubernetes.io/os: linux @@ -1281,6 +1280,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO echoNodeAndGpuInfo(pipeline, stageName) sh "cat ${MODEL_CACHE_DIR}/README" sh "nvidia-smi -q" + sh "nvidia-smi topo -m" sh "df -h" // setup HF_HOME to cache model and datasets @@ -1789,7 +1789,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "A10-TensorRT-4": ["a10", "l0_a10", 4, 6], "A10-TensorRT-5": ["a10", "l0_a10", 5, 6], "A10-TensorRT-6": ["a10", "l0_a10", 6, 6], - "A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1], + "A10-Pybind": ["a10", "l0_a10_pybind", 1, 1], "A30-Triton-1": ["a30", "l0_a30", 1, 1], "A30-PyTorch-1": ["a30", "l0_a30", 1, 2], "A30-PyTorch-2": ["a30", "l0_a30", 2, 2], @@ -1809,8 +1809,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "B200_PCIe-PyTorch-1": ["b100-ts2", "l0_b200", 1, 3], "B200_PCIe-PyTorch-2": ["b100-ts2", "l0_b200", 2, 3], "B200_PCIe-PyTorch-3": ["b100-ts2", "l0_b200", 3, 3], - "B200_PCIe-TensorRT-1": ["b100-ts2", "l0_b200", 1, 2], - "B200_PCIe-TensorRT-2": ["b100-ts2", "l0_b200", 2, 2], "RTX5090-PyTorch-1": ["rtx-5090", "l0_gb202", 1, 1], "RTX5080-TensorRT-1": ["rtx-5080", "l0_gb203", 1, 2], "RTX5080-TensorRT-2": ["rtx-5080", "l0_gb203", 2, 2], @@ -1850,6 +1848,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "H100_PCIe-TensorRT-Post-Merge-5": ["h100-cr", "l0_h100", 5, 5], "B200_PCIe-Triton-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1], "B200_PCIe-PyTorch-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1], + "B200_PCIe-TensorRT-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 2], + "B200_PCIe-TensorRT-Post-Merge-2": ["b100-ts2", "l0_b200", 2, 2], "H100_PCIe-TensorRT-Perf-1": ["h100-cr", "l0_perf", 1, 1], "H100_PCIe-PyTorch-Perf-1": ["h100-cr", "l0_perf", 1, 1], "DGX_H200-8_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x8", "l0_dgx_h200", 1, 1, 8], @@ -1857,6 +1857,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4], "DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4], "DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4], + "RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1], + "RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4], + "RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4], ] parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), { @@ -1867,8 +1870,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) if (key.contains("llvm")) { config = LLVM_CONFIG } - if (key.contains("Nanobind")) { - config = NANOBIND_CONFIG + if (key.contains("Pybind")) { + config = PYBIND_CONFIG } if (key.contains("-CU12-")) { config = VANILLA_CONFIG_CU12 @@ -1878,7 +1881,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) fullSet = parallelJobs.keySet() x86SlurmTestConfigs = [ - "RTXPro6000-PyTorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1], "DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4], ] fullSet += x86SlurmTestConfigs.keySet() @@ -2099,11 +2101,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) checkPipStage = true } - if (cpu_arch == AARCH64_TRIPLE && values[5] != DLFW_IMAGE) { - checkPipStage = false - echo "Skip pip install sanity check due to https://nvbugs/5453827" - } - if (checkPipStage) { stage("Run LLMAPI tests") { pipInstallSanitySpec = createKubernetesPodConfig(values[5], gpu_type, k8s_arch) @@ -2484,7 +2481,7 @@ pipeline { def testPhase2StageName = env.testPhase2StageName if (testPhase2StageName) { - def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200"] + def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200", "RTXPro6000-4_GPUs"] singleGpuJobs = parallelJobs.findAll{!dgxSigns.any{sign -> it.key.contains(sign)}} dgxJobs = parallelJobs.findAll{dgxSigns.any{sign -> it.key.contains(sign)}} } diff --git a/requirements.txt b/requirements.txt index 9e7ce380e9..f61d12ad59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cu128 -c constraints.txt -accelerate>=0.25.0 +accelerate>=1.7.0 build colored cuda-python>=12 diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index d7cd4c61f1..a1275bf106 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -435,7 +435,7 @@ def main(*, install: bool = False, skip_building_wheel: bool = False, linking_install_binary: bool = False, - binding_type: str = "pybind", + binding_type: str = "nanobind", benchmarks: bool = False, micro_benchmarks: bool = False, nvtx: bool = False, @@ -984,8 +984,8 @@ def add_arguments(parser: ArgumentParser): ) parser.add_argument("--binding_type", choices=["pybind", "nanobind"], - default="pybind", - help="Which binding type to build: pybind or nanobind") + default="nanobind", + help="Which binding library to use: pybind or nanobind") parser.add_argument("--benchmarks", action="store_true", help="Build the benchmarks for the C++ runtime") diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index f7ad7934a9..041d51e73d 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -19,6 +19,11 @@ transforms: stage: post_export cleanup_input_constraints: stage: post_export + ############################################################################################ + # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION + ############################################################################################ + match_moe_pattern: + stage: pattern_matcher match_repeat_kv: stage: pattern_matcher match_eager_attention: @@ -27,12 +32,13 @@ transforms: stage: pattern_matcher match_attention_layout: stage: pattern_matcher - match_moe_pattern: - stage: pattern_matcher match_rope_pattern: stage: pattern_matcher match_rope_layout: stage: pattern_matcher + ############################################################################################ + # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION + ############################################################################################ eliminate_redundant_transposes: stage: pattern_matcher # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved @@ -57,5 +63,44 @@ transforms: sharding_transform_executor: stage: sharding run_shape_prop: true + ############################################################################################ + # MOVE MODEL AND LOAD WEIGHTS + ############################################################################################ load_weights: stage: weight_load + ############################################################################################ + # RUN POST-LOAD FUSION AND OPTIMIZATIONS + ############################################################################################ + # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs + # fuse_moe: + # stage: post_load_fusion + # fuse_gemms: + # stage: post_load_fusion + fuse_allreduce_residual_rmsnorm: + stage: post_load_fusion + fuse_collectives: + stage: post_load_fusion + # TODO (lucaslie): add backend selection as part of configurable inference optimizers + # check if we can fuse rmsnorm + fuse_rmsnorm: + stage: post_load_fusion + backend: flashinfer + ############################################################################################ + # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES + ############################################################################################ + update_in_out_nodes: + stage: cache_init + insert_cached_attention: + stage: cache_init + insert_cached_mla_attention: + stage: cache_init + attn_backend: MultiHeadLatentAttention + initialize_cache: + stage: cache_init + resize_kv_cache: + stage: cache_init + ############################################################################################ + # COMPILE MODEL + ############################################################################################ + compile_model: + stage: compile diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 414039a506..01fb0deb57 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -198,7 +198,6 @@ def prepare_flashinfer_metadata( flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size), position_ids.numel(), ) - # return metadata return ( qo_indptr, diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 812dfea29c..9811274a8b 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -274,6 +274,16 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): self._quant_config = value ### VALIDATION ################################################################################# + @field_validator("max_seq_len", mode="before") + @classmethod + def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any: + if value is None: + # Fallback to the AutoDeployConfig default when not provided + return AutoDeployConfig.model_fields["max_seq_len"].get_default( + call_default_factory=True + ) + return value + @field_validator("build_config", mode="before") @classmethod def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 1087714177..cddc56b872 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -54,6 +54,7 @@ class SharedConfig(BaseModel): sharding_config: ShardingConfig = Field(default_factory=ShardingConfig) local_rank: int = Field(default=0) world_size: int = Field(default=1) + attn_backend: str = Field(default="flashinfer", description="The attention backend to use.") class TransformConfig(BaseModel): @@ -285,7 +286,10 @@ class BaseTransform(ABC): # update + store new meta data history[t_name] = info autodeploy_meta[self._history_key] = history - self._set_autodeploy_meta(gm, autodeploy_meta) + + if isinstance(gm, GraphModule): + # After compilation, gm becomes type CapturedGraph with no meta data. + self._set_autodeploy_meta(gm, autodeploy_meta) # return the graph module return gm diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py new file mode 100644 index 0000000000..6c5b1fe2b9 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -0,0 +1,204 @@ +import operator +from typing import Tuple + +import torch +from torch.fx import GraphModule + +from ...distributed.trtllm import is_trtllm_op_available +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + +# TODO: This is an overly simplified model that works well for vanilla Llama models. +# However, we eventually want to consider more sophisticated patterns such as +# * all_reduce(lin1(x) + lin2(x)) +# * version above with fused GEMMs (i.e. with a split node) +# * all_reduce(pointwise_op(linear(x))) +# * ... + + +@TransformRegistry.register("fuse_collectives") +class FuseCollectives(BaseTransform): + """ + Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + num_gemm_collective_fusions = 0 + + # lookup for fused ops + # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. + lookup = { + torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, + } + + # go through all nodes and find all_reduce nodes + for node in gm.graph.nodes: + if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + continue + + # check if args are as expected + assert len(node.args) == 1 and not len(node.kwargs), ( + "Unexpected args/kwargs for all_reduce" + ) + + # retrieve parent and check a few conditions on the parent node + parent_node = node.args[0] + if not is_op(parent_node, lookup.keys()): + continue + if len(parent_node.users) > 1: + continue + + with gm.graph.inserting_before(node): + # insert fused node + fused_linear_collective_node = gm.graph.call_function( + lookup[get_op_overload_packet(parent_node.target)], + args=parent_node.args, + kwargs=parent_node.kwargs, + ) + node.replace_all_uses_with(fused_linear_collective_node) + gm.graph.erase_node(node) + gm.graph.erase_node(parent_node) + num_gemm_collective_fusions += 1 + + info = TransformInfo( + skipped=False, + num_matches=num_gemm_collective_fusions, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("fuse_allreduce_residual_rmsnorm") +class FuseAllreduceResidualRMSNorm(BaseTransform): + """Essentially, this transformation fuses the following operators into one allreduce trtllm implementation. + + * target pattern: + x = all_reduce(x) + y = x + residual + return rmsnorm(y), y + * replacement: + fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) + + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if not is_trtllm_op_available(): + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + num_ar_r_rms_fusions = 0 + + def trace_and_fuse(allreduce_node, graph): + # Check if all_reduce is followed by addition + users = list(allreduce_node.users.keys()) + if len(users) != 1: + return # Skip if all_reduce has more than one consumer + add_node = users[0] + + # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer + # the Huggingface LlamaRMSNorm implementation as example for more details + to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2) + # operand of pow and mul + pow_node = get_user_if_pattern_match( + to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2 + ) + mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1) + add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1) + rsqrt_node = get_user_if_pattern_match( + add_eps_node, [torch.ops.aten.add, operator.add], 1 + ) + mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1) + to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1) + mul_node_2 = get_user_if_pattern_match( + to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1 + ) + # check args of ops: pow(2) and mean(-1) + ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent + ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions + + # Match found: Replace with fused operation + if ( + to_copy_1 + and pow_node + and mean_node + and add_eps_node + and rsqrt_node + and mul_node_1 + and to_copy_2 + and mul_node_2 + and ARGS_MATCH + ): + # Gather the inputs for the custom operation + tensor = allreduce_node.args[0] + # Identify the residual argument in the add operation + # One of the args in add_node.args is the output of all_reduce + # The same idea also applies to norm_weight + residual = ( + add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1] + ) + norm_weight = ( + mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1] + ) + eps = add_eps_node.args[1] + + # Insert nodes + with graph.inserting_before(allreduce_node): + fused_node = graph.call_function( + torch.ops.dist.fused_allreduce_residual_rmsnorm, + args=( + tensor, + residual, + norm_weight, + eps, + ), + ) + # Extract outputs from the tuple returned by `fused_node` + final_output_node = gm.graph.create_node( + "call_function", + target=operator.getitem, + args=(fused_node, 0), + ) + add_output_node = gm.graph.create_node( + "call_function", + target=operator.getitem, + args=(fused_node, 1), + ) + + # Replace all uses of rmsnorm_node with final_output_node + mul_node_2.replace_all_uses_with(final_output_node) + + # Replace all uses of add_node with add_output_node + add_node.replace_all_uses_with(add_output_node) + + nonlocal num_ar_r_rms_fusions + num_ar_r_rms_fusions += 1 + + # Traverse all nodes + for node in gm.graph.nodes: + if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + trace_and_fuse(allreduce_node=node, graph=gm.graph) + + info = TransformInfo( + skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py new file mode 100644 index 0000000000..00601303b6 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -0,0 +1,65 @@ +from typing import List, Literal, Optional, Tuple, Type + +from pydantic import Field +from torch.fx import GraphModule + +from ...compile import compile_and_capture +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +class CompileModelConfig(TransformConfig): + """Configuration for the compile model transform.""" + + cuda_graph_batch_sizes: Optional[List[int]] = Field( + default=None, description="The batch sizes to use for CUDA graphs." + ) + num_batched_inputs: int = Field( + default=2, description="The number of batched inputs to use for CUDA graphs." + ) + compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( + Field(description="The backend to use for compiling the model.") + ) + + +@TransformRegistry.register("compile_model") +class CompileModel(BaseTransform): + """A transform to compile the model.""" + + config: CompileModelConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return CompileModelConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + cm.info.set_generate_only_batch() + egm_compiled = compile_and_capture( + gm, + self.config.compile_backend, + args=cm.args, + dynamic_shapes=cm.dynamic_shapes, + compiler_kwargs={ + "cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes, + "num_batched_inputs": self.config.num_batched_inputs, + }, + ) + cm.info.reset() + + # store info object about the transform + info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) + + return egm_compiled, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py similarity index 76% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py rename to tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index e66ced8ae6..2d422c42d6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from torch.fx import GraphModule, Node +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger from ...utils.node_utils import ( @@ -14,7 +16,7 @@ from ...utils.node_utils import ( is_linear_op, ) from ...utils.quantization_utils import QuantizationImpl -from .._graph import canonicalize_graph +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]): @@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node gm.delete_all_unused_submodules() -def fuse_gemms(gm: GraphModule) -> None: - ad_logger.info("GEMM fusion") - ad_logger.debug("Before GEMM fusion: " + str(gm)) - # sort linear nodes by parent node - linear_nodes = defaultdict(list) - for node in gm.graph.nodes: - # TODO: we don't handle bias for now... - if is_linear_op(node, include_quantization=True) and node.args[2] is None: - linear_nodes[node.args[0]].append(node) +@TransformRegistry.register("fuse_gemms") +class FuseGemms(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # sort linear nodes by parent node + linear_nodes = defaultdict(list) + for node in gm.graph.nodes: + # TODO: we don't handle bias for now... + if is_linear_op(node, include_quantization=True) and node.args[2] is None: + linear_nodes[node.args[0]].append(node) - # fuse linear nodes - idx = -1 - with cuda_memory_tracker(): - for parent_node, lin_children in linear_nodes.items(): - if len(lin_children) < 2: - continue - # linear nodes to fuse - ad_logger.debug( - f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}" - ) - _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children) + # fuse linear nodes + idx = -1 + num_matches = 0 + with cuda_memory_tracker(): + for parent_node, lin_children in linear_nodes.items(): + if len(lin_children) < 2: + continue + # linear nodes to fuse + _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children) + num_matches += 1 - # clean up and return - canonicalize_graph(gm) + torch.cuda.empty_cache() - ad_logger.debug("After GEMM fusion: " + str(gm)) - torch.cuda.empty_cache() + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py new file mode 100644 index 0000000000..80f9d440c1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -0,0 +1,299 @@ +"""Graph transformation to automatically add kv cache into fused MHA op.""" + +import operator +from typing import Dict, Optional, Tuple, Type + +import torch +from pydantic import Field +from torch.fx import Graph, GraphModule, Node + +from ...custom_ops.attention_interface import AttentionRegistry +from ...distributed.common import all_gather_object, get_world_size +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...transformations._graph import add_graph_input +from ...utils.logger import ad_logger +from ...utils.node_utils import get_all_input_output_nodes, is_op +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +@TransformRegistry.register("update_in_out_nodes") +class UpdateInOutNodes(BaseTransform): + """Modify the graph module by adding new input nodes. + + The new input nodes correspond to the extra arguments needed for cached and flattened attention. + + Args: + egm: The graph module to analyze and modify. + cm: Cached sequence interface containing extra argument information. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # loop through nodes to get input, output, and get_attr nodes + input_nodes, output_nodes = get_all_input_output_nodes(gm.graph) + + # we only expect one input node + assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." + + # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. + # Later on, we can revisit how to support returning hidden states. + assert len(output_nodes) == 1, "Expected exactly one output node!" + assert len(output_nodes[0].all_input_nodes) == 1, ( + "Expected to only return final tensor output!" + ) + + # Activate and add extra argument nodes + new_args = cm.info.switch_to_cached_attn_inputs() + for name in new_args: + input_nodes.append(add_graph_input(gm, name)) + + info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False) + + return gm, info + + +class InsertCachedAttentionConfig(TransformConfig): + """Configuration for the insert cached attention transform.""" + + attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.") + + +@TransformRegistry.register("insert_cached_attention") +class InsertCachedAttention(BaseTransform): + """ + A transform to insert cached attention into the graph module. + + If attn_backend is not provided in transform config, will find from shared config. + """ + + config: InsertCachedAttentionConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return InsertCachedAttentionConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + """Replace uncached source attention node with corresponding cached attn node.""" + attn_descriptor = AttentionRegistry.get(self.config.attn_backend) + + cache_config = factory.get_cache_config() + + # Get all attention nodes and their info objects + source_op = attn_descriptor.get_source_attention_op() + + # pick up graph + graph: Graph = gm.graph + + # look for relevant source attention nodes + source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)] + + if not source_attn_nodes: + # If there are no nodes for kv cache insertion found, return current graph + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + # Sanity check + if cm.info.is_paged: + assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." + + # retrieve input nodes + input_nodes, _ = get_all_input_output_nodes(gm.graph) + + # insert metadata computation and extract each argument as a node + get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() + with graph.inserting_before(input_nodes[-1].next): + ret_node = graph.call_function( + get_metadata, + args=( + *input_nodes, + cm.info.page_size, + ), + ) + metadata_nodes = [ + graph.call_function(operator.getitem, args=(ret_node, idx)) + for idx in range(num_metadata) + ] + + buffer_in_lookup: Dict[str, Node] = {} + + # replace fused attention node with attention node that has kv cache + num_cached_attn_replacements = 0 + for idx, attn_node in enumerate(source_attn_nodes): + # pick out GEMMs + qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] + + # setup + store cache initializers and caches as input nodes + cache_in_nodes = [] + for k, get_cache in attn_descriptor.get_cache_initializers( + attn_node, cache_config + ).items(): + k_indexed = f"{k}_{idx}" + cm.add_cache(k_indexed, get_cache) + cache_in_nodes.append(add_graph_input(gm, k_indexed)) + + # setup + store global buffer initializers and buffers as input nodes + # NOTE: we have to check against existing keys to make sure nothing is registered twice... + buffer_in_nodes = [] + for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items(): + if k not in buffer_in_lookup: + cm.add_cache(k, get_buffer) + buffer_in_lookup[k] = add_graph_input(gm, k) + buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op + + # retrieve constants for attention_op + constants = attn_descriptor.get_constants(attn_node) + + # insert cached attention replacement op + with graph.inserting_before(attn_node): + cached_attn_node = graph.call_function( + attn_descriptor.get_cached_attention_op(), + args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants), + ) + attn_node.replace_all_uses_with(cached_attn_node) + graph.erase_node(attn_node) + num_cached_attn_replacements += 1 + + info = TransformInfo( + skipped=False, + num_matches=num_cached_attn_replacements, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("insert_cached_mla_attention") +class InsertCachedMLAAttention(InsertCachedAttention): + """ + A transform to insert cached MLA attention into the graph module. + + This class is identical to InsertCachedAttention and inherits all its behavior. + """ + + pass + + +class ResizeKVCacheConfig(TransformConfig): + """Configuration for the resize kv cache transform.""" + + free_mem_ratio: float = Field( + description="The fraction of available memory to occupy.", default=0.8 + ) + + +@TransformRegistry.register("resize_kv_cache") +class ResizeKVCache(BaseTransform): + """Inflate the kv cache to occupy the available GPU memory. + + free_mem_ratio specifies the fraction of available memory to occupy. + """ + + config: ResizeKVCacheConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return ResizeKVCacheConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + free_mem_ratio = self.config.free_mem_ratio + + def _get_mem_info_in_mb(): + free_mem, total_mem = torch.cuda.mem_get_info() + return free_mem // 1024**2, total_mem // 1024**2 + + free_mem, total_mem = _get_mem_info_in_mb() + ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") + current_cache_size = cm.current_cache_size_bytes() + current_num_pages = cm.info.num_pages + ad_logger.info( + f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}" + ) + + if free_mem_ratio == 0.0: + ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + try: + # Let's run a forward pass to get the memory usage + cm.info._set_max_num_tokens_sample() + free_mem_pre, _ = _get_mem_info_in_mb() + ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") + + gm(*cm.args) + + free_mem_post, _ = _get_mem_info_in_mb() + ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}") + + memory_for_forward_pass = free_mem_pre - free_mem_post + ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}") + + new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size + new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) + + # Need to sync all the GPUs + gathered_num_pages = [None] * get_world_size() + all_gather_object(gathered_num_pages, new_num_pages) + new_num_pages = min(gathered_num_pages) + ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}") + + cm.resize_cache(new_num_pages) + except Exception as e: + ad_logger.warning( + f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize." + ) + + # Free memory + torch.cuda.empty_cache() + + info = TransformInfo( + skipped=False, + num_matches=0, + is_clean=True, + has_valid_shapes=True, + ) + + return gm, info + + +@TransformRegistry.register("initialize_cache") +class InitializeCache(BaseTransform): + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + cm.initialize_caches() + + info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py new file mode 100644 index 0000000000..1772037d93 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py @@ -0,0 +1,148 @@ +"""Graph transform to optimize RMSNorm execution using FlashInfer.""" + +from functools import partial +from typing import Tuple, Type + +import torch +from pydantic import Field +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface + +# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + +_BACKEND_OPS = { + "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm, + "triton": torch.ops.auto_deploy.triton_rms_norm, + "torch": torch.ops.auto_deploy.torch_rmsnorm, +} + + +def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + """Implements the RMSNorm pattern for pattern matching. + + Args: + data: Input tensor to normalize. + weight: Scaling weights for the normalized output. + eps: Small constant for numerical stability. + + Returns: + Normalized and scaled tensor. + """ + input_dtype = data.dtype + data = data.to(torch.float32) + variance = data.pow(2).mean(-1, keepdim=True) + data = data * torch.rsqrt(variance + eps) + return weight * data.to(input_dtype) + + +def _rms_norm_replacement( + data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str +) -> torch.Tensor: + """Backend-specific rms_norm implementation. + + Args: + data: Input tensor to normalize. + weight: Scaling weights for the normalized output. + eps: Small constant for numerical stability. + backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). + + Returns: + Normalized and scaled tensor using the specified backend implementation. + """ + + assert backend.lower() in _BACKEND_OPS, ( + f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}" + ) + return _BACKEND_OPS[backend.lower()](data, weight, eps) + + +class FuseRMSNormConfig(TransformConfig): + """Configuration for the RMSNorm fusion transform.""" + + backend: str = Field( + default="flashinfer", + description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').", + ) + + +@TransformRegistry.register("fuse_rmsnorm") +class FuseRMSNorm(BaseTransform): + """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation. + + This function sets up pattern matching to identify RMSNorm operations in the graph + and replaces them with optimized implementations. It uses dummy tensors to register + the pattern matching rules. + + Args: + gm: Input graph module to transform. + backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). + + Returns: + Transformed graph module with optimized RMSNorm operations. + """ + + config: FuseRMSNormConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseRMSNormConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if self.config.backend.lower() not in _BACKEND_OPS: + raise ValueError( + f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}" + ) + + graph = gm.graph + patterns = ADPatternMatcherPass() + + # Create dummy tensors for pattern matching + bs = 2 + hidden_size = 512 + + def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6): + return [ + torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype), + torch.randn(hidden_size, device="cuda", dtype=weight_dtype), + eps, + ] + + # Define configurations for different data types + configs = [ + (torch.bfloat16, torch.bfloat16), + (torch.float16, torch.float16), + (torch.float32, torch.float32), + ] + + # Register patterns for each configuration + for input_dtype, weight_dtype in configs: + register_ad_pattern( + search_fn=_rms_norm_pattern, + replace_fn=partial(_rms_norm_replacement, backend=self.config.backend), + patterns=patterns, + dummy_args=dummy_args(input_dtype, weight_dtype), + op_ignore_types={}, + scalar_workaround={"eps": 1e-6}, + ) + + cnt = patterns.apply(graph) + + info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py index 0d4c388ebc..e5260ada48 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py @@ -1,11 +1,5 @@ """A library of transformation passes.""" -from .collectives import * -from .fused_moe import * -from .fusion import * -from .kvcache import * -from .rms_norm import * - try: from .visualization import visualize_namespace except ImportError: diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py deleted file mode 100644 index 8cec047561..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py +++ /dev/null @@ -1,167 +0,0 @@ -import operator - -import torch -from torch.fx import GraphModule - -from ...distributed.trtllm import is_trtllm_op_available -from ...utils.logger import ad_logger -from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op -from .._graph import canonicalize_graph - - -# TODO: This is an overly simplified model that works well for vanilla Llama models. -# However, we eventually want to consider more sophisticated patterns such as -# * all_reduce(lin1(x) + lin2(x)) -# * version above with fused GEMMs (i.e. with a split node) -# * all_reduce(pointwise_op(linear(x))) -# * ... -def fuse_collectives(gm: GraphModule) -> None: - num_gemm_collective_fusions = 0 - ad_logger.debug("Before GEMM+Collective fusion: " + str(gm)) - - # lookup for fused ops - # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. - lookup = { - torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, - torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, - torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, - } - - # go through all nodes and find all_reduce nodes - for node in gm.graph.nodes: - if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): - continue - - # check if args are as expected - assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce" - - # retrieve parent and check a few conditions on the parent node - parent_node = node.args[0] - if not is_op(parent_node, lookup.keys()): - continue - if len(parent_node.users) > 1: - continue - - with gm.graph.inserting_before(node): - # insert fused node - fused_linear_collective_node = gm.graph.call_function( - lookup[get_op_overload_packet(parent_node.target)], - args=parent_node.args, - kwargs=parent_node.kwargs, - ) - node.replace_all_uses_with(fused_linear_collective_node) - gm.graph.erase_node(node) - gm.graph.erase_node(parent_node) - num_gemm_collective_fusions += 1 - - canonicalize_graph(gm) - ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions") - ad_logger.debug("After GEMM+Collective fusion: " + str(gm)) - - -def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None: - """Essentially, this function fuses the following operators into one allreduce trtllm implementation. - - * target pattern: - x = all_reduce(x) - y = x + residual - return rmsnorm(y), y - * replacement: - fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) - - """ - if not is_trtllm_op_available(): - return - - num_ar_r_rms_fusions = 0 - ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm)) - - def trace_and_fuse(allreduce_node, graph): - # Check if all_reduce is followed by addition - users = list(allreduce_node.users.keys()) - if len(users) != 1: - return # Skip if all_reduce has more than one consumer - add_node = users[0] - - # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer - # the Huggingface LlamaRMSNorm implementation as example for more details - to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2) - # operand of pow and mul - pow_node = get_user_if_pattern_match( - to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2 - ) - mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1) - add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1) - rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1) - mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1) - to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1) - mul_node_2 = get_user_if_pattern_match( - to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1 - ) - # check args of ops: pow(2) and mean(-1) - ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent - ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions - - # Match found: Replace with fused operation - if ( - to_copy_1 - and pow_node - and mean_node - and add_eps_node - and rsqrt_node - and mul_node_1 - and to_copy_2 - and mul_node_2 - and ARGS_MATCH - ): - # Gather the inputs for the custom operation - tensor = allreduce_node.args[0] - # Identify the residual argument in the add operation - # One of the args in add_node.args is the output of all_reduce - # The same idea also applies to norm_weight - residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1] - norm_weight = ( - mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1] - ) - eps = add_eps_node.args[1] - - # Insert nodes - with graph.inserting_before(allreduce_node): - fused_node = graph.call_function( - torch.ops.dist.fused_allreduce_residual_rmsnorm, - args=( - tensor, - residual, - norm_weight, - eps, - ), - ) - # Extract outputs from the tuple returned by `fused_node` - final_output_node = gm.graph.create_node( - "call_function", - target=operator.getitem, - args=(fused_node, 0), - ) - add_output_node = gm.graph.create_node( - "call_function", - target=operator.getitem, - args=(fused_node, 1), - ) - - # Replace all uses of rmsnorm_node with final_output_node - mul_node_2.replace_all_uses_with(final_output_node) - - # Replace all uses of add_node with add_output_node - add_node.replace_all_uses_with(add_output_node) - - nonlocal num_ar_r_rms_fusions - num_ar_r_rms_fusions += 1 - - # Traverse all nodes - for node in gm.graph.nodes: - if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): - trace_and_fuse(allreduce_node=node, graph=gm.graph) - - canonicalize_graph(gm) - ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions") - ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm)) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py deleted file mode 100644 index e049970862..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py +++ /dev/null @@ -1,511 +0,0 @@ -from collections import defaultdict -from typing import Optional - -import torch -from torch.fx import GraphModule, Node - -from ...utils.cuda_mem_tracker import cuda_memory_tracker -from ...utils.logger import ad_logger -from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op -from ...utils.quantization_utils import get_scales_and_type_from_node -from .._graph import canonicalize_graph - - -def match_moe_pattern(gm: GraphModule) -> None: - graph = gm.graph - - ad_logger.debug("Before MoE Pattern Matching: " + str(gm)) - # Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph. - boundary_nodes = identify_regions_between_residuals(gm) - - num_moe_patterns = 0 - - for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): - # Step 1: Identify Expert Compute pattern - (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = ( - _match_expert_compute_pattern(start_boundary, end_boundary) - ) - if not expert_weights: - continue - # TODO: naming convention to verify the order of the weight nodes - - # Step 2: Trace upwards to locate normalize_routing_weight and selected_experts: - arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes) - normalized_routing_weights = _find_lowest_common_ancessor(arg1_list) - if not normalized_routing_weights: - continue - - common_ancessor2 = _find_lowest_common_ancessor(arg2_list) - if not common_ancessor2: - continue - selected_experts = bfs( - common_ancessor2, - lambda node: is_op(node, torch.ops.aten.one_hot), - attr_next="all_input_nodes", - boundary=start_boundary, - ).args[0] - if not selected_experts: - continue - - # Step 3: Trace upwards to find input node: - hidden_states = _find_lowest_common_ancessor(pattern_input_nodes) - if not hidden_states: - continue - - # Step 4: Find output node with the combine pattern - final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary) - if final_hidden_state_node is None: - continue - - # Step 5: Insert the MoE op into the graph. - ad_logger.debug( - f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n" - f"Input hidden states node: {hidden_states}, " - f"selected_experts node: {selected_experts}, " - f"routing_weights node: {normalized_routing_weights}, " - f"expert weights: {expert_weights}, weight type: {weight_type}" - ) - with graph.inserting_before(final_hidden_state_node): - w1_list = expert_weights["w1"] - w2_list = expert_weights["w2"] - w3_list = expert_weights["w3"] - - if weight_type == "fp8": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp8_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - ), - ) - elif weight_type == "fp4": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp4_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - expert_scales["w1_alpha"], - expert_scales["w2_alpha"], - expert_scales["w3_alpha"], - ), - ) - else: - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - ), - ) - - final_hidden_state_node.replace_all_uses_with(fused_moe_node) - graph.erase_node(final_hidden_state_node) - - while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary): - gm.graph.eliminate_dead_code() - - num_moe_patterns += 1 - - canonicalize_graph(gm) - - ad_logger.info(f"Found {num_moe_patterns} MoE Patterns") - ad_logger.debug("After MoE Pattern Matching: " + str(gm)) - - -def fuse_moe(gm: torch.fx.GraphModule) -> None: - """ - Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with - torch.ops.auto_deploy.trtllm_moe_fused. - """ - ad_logger.debug("Before MoE fusion: " + str(gm)) - - with cuda_memory_tracker(): - fused_key_counter = _insert_fused_moe_ops(gm) - if fused_key_counter: - canonicalize_graph(gm) - - ad_logger.info(f"Found {fused_key_counter} MoE fusions") - ad_logger.debug("After MoE fusion: " + str(gm)) - - -def _insert_fused_moe_ops(gm: GraphModule) -> int: - fused_key_counter = 0 - graph = gm.graph - - for node in list(graph.nodes): - if not is_op(node, torch.ops.auto_deploy.torch_moe): - continue - - ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}") - hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args - - fused_w3_w1_experts = torch.stack( - [ - torch.cat( - [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2 - ) - for w1_node, w3_node in zip(w1_list, w3_list) - ], - dim=0, - ) - - fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) - - new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}" - fused_key_counter += 1 - param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts) - param_w2 = torch.nn.Parameter(fused_w2_experts) - gm.register_parameter(new_key_w3_w1, param_w3_w1) - gm.register_parameter(new_key_w2, param_w2) - - with graph.inserting_before(node): - new_node = graph.call_function( - # TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models - torch.ops.auto_deploy.trtllm_moe_fused, - args=( - hidden_states, - selected_experts, - routing_weights, - graph.get_attr(new_key_w3_w1), - graph.get_attr(new_key_w2), - ), - ) - - node.replace_all_uses_with(new_node) - graph.erase_node(node) - - return fused_key_counter - - -def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: - """ - Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following - each node's primary branch (recursively following the first Node argument). - - It first finds the LCA of the first two nodes and then - iteratively computes the LCA of the result with the next node, and so on. - - Returns: - The common ancestor Node if found, otherwise None. - """ - if not nodes: - return None - - def get_parent(node: Node) -> Optional[Node]: - """Return the first Node-valued argument for a given node, or None if not found.""" - for arg in node.args: - if isinstance(arg, Node): - return arg - return None - - def get_depth(node: Node) -> int: - """ - Recursively compute the depth of the node by following its primary branch. - Depth is defined as the number of steps to reach a node with no parent. - """ - parent = get_parent(node) - if parent is None: - return 0 - return 1 + get_depth(parent) - - def lca_two(a: Node, b: Node) -> Optional[Node]: - """ - Find the lowest common ancestor of two nodes by first equalizing their depth - and then moving upward until a common node is found. - """ - depth_a = get_depth(a) - depth_b = get_depth(b) - - # Equalize depths - while depth_a > depth_b: - a = get_parent(a) - depth_a -= 1 - while depth_b > depth_a: - b = get_parent(b) - depth_b -= 1 - - # Walk upward in lockstep - while a is not None and b is not None: - if a is b: - return a - a = get_parent(a) - b = get_parent(b) - return None - - # Iteratively compute the LCA across all nodes. - common = nodes[0] - for node in nodes[1:]: - common = lca_two(common, node) - if common is None: - return None - - return common - - -def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]: - """ - Given a linear op node, extract the input tensor node, weight tensor, - any quantization scales (if the op is quantized), and return a weight type. - - For a torch.ops.auto_deploy.torch_linear_simple.default op: - - Returns (input_node, weight, None, "simple") - - For a torch.ops.auto_deploy.torch_quant_fp8_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8") - For a torch.ops.auto_deploy.torch_quant_fp4_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4") - """ - input_node = linear_node.args[0] - if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple): - weight = linear_node.args[1] - return input_node, weight, None, "" - elif { - is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear), - is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear), - }: - weight = linear_node.args[1] - scales, quant_type = get_scales_and_type_from_node(linear_node) - return input_node, weight, scales, quant_type - - -def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): - """ - Match the expert compute pattern between the given boundaries. - - The expert compute pattern corresponds to: - - (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t() - - For each expert, the function extracts the input node from the w1 branch and - collects the weight parameters from three linear ops (w1, w3, and w2 branches). - - This function supports both: - - torch.ops.auto_deploy.torch_linear_simple.default ops, and - - torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales). - - torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales). - - Returns: - A tuple: - (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) - - - pattern_input_nodes: List of input nodes (x) used for the expert compute. - - pattern_output_nodes: List of final expert output nodes (the linear op with weight w2). - - expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors. - - expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors - (empty if weight_type is "simple"). - - weight_type: "fp8" if FP8 ops were used, "simple" otherwise. - """ - pattern_input_nodes, pattern_output_nodes = [], [] - expert_weights = defaultdict(list) - expert_scales = defaultdict(list) - weight_type = "simple" # default - - nodes = list(start_boundary.graph.nodes) - region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] - - for node in region_nodes: - # Accept both simple and quantized linear ops. - if not is_linear_op(node, include_quantization=True): - continue - - final_linear = node - if not final_linear.args or not isinstance(final_linear.args[0], Node): - continue - - mul_node = final_linear.args[0] - if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2: - continue - - arg_a, arg_b = mul_node.args[:2] - silu_node = ( - arg_a - if is_op(arg_a, torch.ops.aten.silu) - else arg_b - if is_op(arg_b, torch.ops.aten.silu) - else None - ) - if silu_node is None: - continue - - if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)): - continue - linear_w1_node = silu_node.args[0] - - # The other branch should be a linear op (w3 branch). - linear_w3_node = arg_b if arg_a is silu_node else arg_a - if not is_linear_op(linear_w3_node, include_quantization=True): - continue - if not (linear_w1_node.args and linear_w3_node.args): - continue - - # Extract parameters from each linear op. - input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters( - linear_w1_node - ) - _, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node) - _, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear) - - if None in (weight_w1, weight_w3, weight_w2): - continue - - # Ensure the weight type is consistent across branches. - if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2: - continue - weight_type = wt_type_w1 - - pattern_input_nodes.append(input_node_w1) - pattern_output_nodes.append(final_linear) - expert_weights["w1"].append(weight_w1) - expert_weights["w3"].append(weight_w3) - expert_weights["w2"].append(weight_w2) - - # TODO: sanity check that all experts have same weight type - if weight_type == "fp8": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - elif weight_type == "fp4": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w1_alpha"].append(quant_params_w1["alpha"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w3_alpha"].append(quant_params_w3["alpha"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - expert_scales["w2_alpha"].append(quant_params_w2["alpha"]) - - return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type - - -def _find_final_hidden_state_node( - pattern_output_nodes: list[Node], end_boundary: Node -) -> Optional[Node]: - """ - Identify the final hidden state node corresponding to the combine pattern: - - (expert_output * routing_weight) → index_add_ - - For each expert output node (from the expert compute pattern), this function: - 1. Retrieves a multiplication node from its users. - 2. Extracts the second argument from the multiplication node (assumed to be the index node). - 3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary). - - After collecting all such index_add_ nodes, the final hidden state node is determined - as the one that is not used by any of the other index_add_ nodes. - - If any required attribute (users or args) is missing during the process or if no valid - final node is found, the function returns None. - """ - - if not pattern_output_nodes: - return None - - index_add_nodes = [] - for node in pattern_output_nodes: - if not node.users: - return None - mul_node = next(iter(node.users)) - if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2): - return None - index_node = mul_node.args[1] - index_add_node = bfs( - index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary - ) - if not index_add_node: - return None - index_add_nodes.append(index_add_node) - - # The final node is defined as the index_add_node that is not used by any other index_add_nodes - return next( - ( - candidate - for candidate in index_add_nodes - if not any( - candidate in other.args for other in index_add_nodes if candidate is not other - ) - ), - None, - ) - - -def _extract_index_branches_from_expert_outputs( - pattern_output_nodes: list[Node], -) -> tuple[list[Node], list[Node]]: - """ - Extract routing and experts branches from expert outputs. - - For each expert output, find its multiplication user. From the - multiplication node's second argument (an index node), - extract: - - The first argument as the routing branch. - - The second argument (flattened if a list/tuple) as the experts branch. - - Returns: - A tuple (routing_branches, experts_branches). - """ - routing_branches, experts_branches = [], [] - for out in pattern_output_nodes: - mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None) - if not mul or len(mul.args) < 2: - continue - idx_node = mul.args[1] - if not is_op(idx_node, torch.ops.aten.index): - continue - routing_branches.append(idx_node.args[0]) - experts = idx_node.args[1] - experts_branches.extend(experts) if isinstance( - experts, (list, tuple) - ) else experts_branches.append(experts) - return routing_branches, experts_branches - - -def _remove_dead_inplace_nodes_in_region( - graph: torch.fx.Graph, - start_boundary: torch.fx.Node, - end_boundary: torch.fx.Node, -) -> bool: - """ - Searches (via BFS) for a dead in-place node (index_add_) in the region - between start_boundary and end_boundary. If one is found, it is removed from the graph. - Returns True if a node was removed, False otherwise. - """ - - def target(n: torch.fx.Node) -> bool: - return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 - - try: - node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) - ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}") - graph.erase_node(node_to_remove) - return True - except RuntimeError: - return False diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py deleted file mode 100644 index 618c8108f8..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Graph transformation to automatically add kv cache into fused MHA op.""" - -import operator -from typing import Dict, Type - -import torch -from torch.fx import Graph, GraphModule, Node - -from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig -from ...distributed.common import all_gather_object, get_world_size -from ...shim.interface import CachedSequenceInterface -from ...utils.logger import ad_logger -from ...utils.node_utils import get_all_input_output_nodes, is_op -from .._graph import add_graph_input, canonicalize_graph - - -def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None: - """Modify the graph module by adding new input nodes and canonicalizing the graph. - - The new input nodes correspond to the extra arguments needed for cached and flattened attention. - - Args: - egm: The graph module to analyze and modify. - cm: Cached sequence interface containing extra argument information. - """ - # loop through nodes to get input, output, and get_attr nodes - input_nodes, output_nodes = get_all_input_output_nodes(egm.graph) - - # we only expect one input node - assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." - - # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. - # Later on, we can revisit how to support returning hidden states. - assert len(output_nodes) == 1, "Expected exactly one output node!" - assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" - - ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes") - - # Activate and add extra argument nodes - new_args = cm.info.switch_to_cached_attn_inputs() - for name in new_args: - input_nodes.append(add_graph_input(egm, name)) - ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata") - - canonicalize_graph(egm) - - -def insert_cached_attention( - egm: GraphModule, - cm: CachedSequenceInterface, - attn_descriptor: Type[AttentionDescriptor], - cache_config: CacheConfig, -) -> None: - """Replace uncached source attention node with corresponding cached attn node.""" - # Get all attention nodes and their info objects - source_op = attn_descriptor.get_source_attention_op() - - # pick up graph - graph: Graph = egm.graph - - # look for relevant source attention nodes - source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)] - - if not source_attn_nodes: - # If there are no nodes for kv cache insertion found, return current graph - return - - # Sanity check - if cm.info.is_paged: - assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." - - ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}") - - # retrieve input nodes - input_nodes, _ = get_all_input_output_nodes(egm.graph) - - # insert metadata computation and extract each argument as a node - get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() - with graph.inserting_before(input_nodes[-1].next): - ret_node = graph.call_function( - get_metadata, - args=( - *input_nodes, - cm.info.page_size, - ), - ) - metadata_nodes = [ - graph.call_function(operator.getitem, args=(ret_node, idx)) - for idx in range(num_metadata) - ] - - buffer_in_lookup: Dict[str, Node] = {} - - # replace fused attention node with attention node that has kv cache - num_cached_attn_replacements = 0 - for idx, attn_node in enumerate(source_attn_nodes): - # pick out GEMMs - qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()] - - # setup + store cache initializers and caches as input nodes - cache_in_nodes = [] - for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items(): - k_indexed = f"{k}_{idx}" - cm.add_cache(k_indexed, get_cache) - cache_in_nodes.append(add_graph_input(egm, k_indexed)) - - # setup + store global buffer initializers and buffers as input nodes - # NOTE: we have to check against existing keys to make sure nothing is registered twice... - buffer_in_nodes = [] - for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items(): - if k not in buffer_in_lookup: - cm.add_cache(k, get_buffer) - buffer_in_lookup[k] = add_graph_input(egm, k) - buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op - - # retrieve constants for attention_op - constants = attn_descriptor.get_constants(attn_node) - - # insert cached attention replacement op - with graph.inserting_before(attn_node): - cached_attn_node = graph.call_function( - attn_descriptor.get_cached_attention_op(), - args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants), - ) - attn_node.replace_all_uses_with(cached_attn_node) - graph.erase_node(attn_node) - num_cached_attn_replacements += 1 - - canonicalize_graph(egm) - ad_logger.info( - f"Replaced {num_cached_attn_replacements} {source_op} ops " - f"with {attn_descriptor.get_cached_attention_op()}" - ) - ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}") - - -def resize_kv_cache( - egm: GraphModule, - cm: CachedSequenceInterface, - free_mem_ratio: float = 0.8, -) -> None: - """Inflate the kv cache to occupy the available GPU memory. - - free_mem_ratio specifies the fraction of available memory to occupy. - """ - - def _get_mem_info_in_mb(): - free_mem, total_mem = torch.cuda.mem_get_info() - return free_mem // 1024**2, total_mem // 1024**2 - - free_mem, total_mem = _get_mem_info_in_mb() - ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") - current_cache_size = cm.current_cache_size_bytes() - current_num_pages = cm.info.num_pages - ad_logger.info( - f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}" - ) - - if free_mem_ratio == 0.0: - ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}") - return - - try: - # Let's run a forward pass to get the memory usage - cm.info._set_max_num_tokens_sample() - free_mem_pre, _ = _get_mem_info_in_mb() - ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") - - egm(*cm.args) - - free_mem_post, _ = _get_mem_info_in_mb() - ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}") - - memory_for_forward_pass = free_mem_pre - free_mem_post - ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}") - - new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size - new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) - - # Need to sync all the GPUs - gathered_num_pages = [None] * get_world_size() - all_gather_object(gathered_num_pages, new_num_pages) - new_num_pages = min(gathered_num_pages) - ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}") - - cm.resize_cache(new_num_pages) - except Exception as e: - ad_logger.warning( - f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize." - ) - - # Free memory - torch.cuda.empty_cache() diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py deleted file mode 100644 index a94758b181..0000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Graph transform to optimize RMSNorm execution using FlashInfer.""" - -from functools import partial - -import torch -from torch.fx import GraphModule - -from ...utils.logger import ad_logger - -# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher -from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern -from .._graph import canonicalize_graph - -_BACKEND_OPS = { - "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm, - "triton": torch.ops.auto_deploy.triton_rms_norm, - "torch": torch.ops.auto_deploy.torch_rmsnorm, -} - - -def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: - """Implements the RMSNorm pattern for pattern matching. - - Args: - data: Input tensor to normalize. - weight: Scaling weights for the normalized output. - eps: Small constant for numerical stability. - - Returns: - Normalized and scaled tensor. - """ - input_dtype = data.dtype - data = data.to(torch.float32) - variance = data.pow(2).mean(-1, keepdim=True) - data = data * torch.rsqrt(variance + eps) - return weight * data.to(input_dtype) - - -def _rms_norm_replacement( - data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str -) -> torch.Tensor: - """Backend-specific rms_norm implementation. - - Args: - data: Input tensor to normalize. - weight: Scaling weights for the normalized output. - eps: Small constant for numerical stability. - backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). - - Returns: - Normalized and scaled tensor using the specified backend implementation. - """ - - assert backend.lower() in _BACKEND_OPS, ( - f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}" - ) - return _BACKEND_OPS[backend.lower()](data, weight, eps) - - -def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None: - """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation. - - This function sets up pattern matching to identify RMSNorm operations in the graph - and replaces them with optimized implementations. It uses dummy tensors to register - the pattern matching rules. - - Args: - gm: Input graph module to transform. - backend: Backend to use for RMSNorm computation ("flashinfer" or "triton"). - - Returns: - Transformed graph module with optimized RMSNorm operations. - """ - if backend.lower() not in _BACKEND_OPS: - raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}") - ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}") - - graph = gm.graph - patterns = ADPatternMatcherPass() - - # Create dummy tensors for pattern matching - bs = 2 - hidden_size = 512 - - def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6): - return [ - torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype), - torch.randn(hidden_size, device="cuda", dtype=weight_dtype), - eps, - ] - - # Define configurations for different data types - configs = [ - (torch.bfloat16, torch.bfloat16), - (torch.float16, torch.float16), - (torch.float32, torch.float32), - ] - - # Register patterns for each configuration - for input_dtype, weight_dtype in configs: - register_ad_pattern( - search_fn=_rms_norm_pattern, - replace_fn=partial(_rms_norm_replacement, backend=backend), - patterns=patterns, - dummy_args=dummy_args(input_dtype, weight_dtype), - op_ignore_types={}, - scalar_workaround={"eps": 1e-6}, - ) - - cnt = patterns.apply(graph) - ad_logger.info(f"RMSNorm pattern count: {cnt}") - canonicalize_graph(gm) - ad_logger.debug("RMSNorm pattern matching completed.") diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index c841b4601f..931c8ec955 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -5,21 +5,11 @@ import gc import torch import torch.nn as nn -from ..compile import compile_and_capture from ..custom_ops.attention_interface import AttentionRegistry from ..llm_args import AutoDeployConfig from ..models.factory import ModelFactory from ..shim.interface import CachedSequenceInterface from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer -from ..utils.logger import ad_logger -from .library import ( - fuse_allreduce_residual_rmsnorm, - fuse_collectives, - fuse_rmsnorm, - insert_cached_attention, - resize_kv_cache, - update_in_out_nodes, -) class InferenceOptimizer: @@ -55,88 +45,60 @@ class InferenceOptimizer: self.ad_config.attn_backend ).get_attention_layout() - new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms) - - # TODO (hg): similar to above. - if "load_weights" in new_optimizer.config: - new_optimizer.config[ + if "load_weights" in self.ad_config.transforms: + self.ad_config.transforms[ "load_weights" ].checkpoint_device = self.ad_config.checkpoint_device - new_optimizer.config["load_weights"].device = cm.device + self.ad_config.transforms["load_weights"].device = cm.device + + if "resize_kv_cache" in self.ad_config.transforms: + self.ad_config.transforms[ + "resize_kv_cache" + ].free_mem_ratio = self.ad_config.free_mem_ratio + if "insert_cached_attention" in self.ad_config.transforms: + self.ad_config.transforms[ + "insert_cached_attention" + ].attn_backend = self.ad_config.attn_backend + if "insert_cached_mla_attention" in self.ad_config.transforms: + self.ad_config.transforms[ + "insert_cached_mla_attention" + ].attn_backend = self.ad_config.mla_backend + + # TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed. + # Old code: + # detect attention op and replace with cache-aware op + # for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]: + # attn_descriptor = AttentionRegistry.get(a_backend) + # insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config()) + + if "compile_model" in self.ad_config.transforms: + self.ad_config.transforms[ + "compile_model" + ].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes + self.ad_config.transforms[ + "compile_model" + ].compile_backend = self.ad_config.compile_backend + + new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms) + # TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config + new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend egm = new_optimizer(cm) - # TODO (lucaslie): continue moving legacy transforms to the new optimizer - ############################################################################################ - # RUN POST-LOAD FUSION AND OPTIMIZATIONS - ############################################################################################ + # NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule. + # We can add a new stage in the optimizer to visualize the intermediate gm. + # if self.ad_config.visualize: + # try: + # from .library import visualize_namespace - # run MoE fusion - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs - # fuse_moe(egm) - - # run GEMM fusion - # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs - # fuse_gemms(egm) - - # check if we can fuse allreduce, residual and rmsnorm - fuse_allreduce_residual_rmsnorm(egm) - - # check if we can fuse collectives - fuse_collectives(egm) - - # TODO (lucaslie): add backend selection as part of configurable inference optimizers - # check if we can fuse rmsnorm - fuse_rmsnorm(egm, "flashinfer") - - # visualize the final graph - if self.ad_config.visualize: - try: - from .library import visualize_namespace - - visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes) - ad_logger.warning( - "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize" - " the graph." - ) - except ImportError: - pass - - ############################################################################################ - # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES - ############################################################################################ - - update_in_out_nodes(egm, cm) - - # detect attention op and replace with cache-aware op - for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]: - attn_descriptor = AttentionRegistry.get(a_backend) - insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config()) - - # initialize cache on correct device - cm.initialize_caches() - - # resize kv cache to occupy the available GPU memory up to free_mem_ratio - resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio) - - ############################################################################################ - # COMPILE MODEL - ############################################################################################ - - cm.info.set_generate_only_batch() - compiler_kwargs = { - "cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes, - "num_batched_inputs": 2, # TODO (lucaslie): improve once we have a config system... - } - egm_compiled = compile_and_capture( - egm, - self.ad_config.compile_backend, - args=cm.args, - dynamic_shapes=cm.dynamic_shapes, - compiler_kwargs=compiler_kwargs, - ) - cm.info.reset() + # visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes) + # ad_logger.warning( + # "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize" + # " the graph." + # ) + # except ImportError: + # pass torch.cuda.empty_cache() gc.collect() - return egm_compiled + return egm diff --git a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py index 777c5787ef..e0e21b1d70 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py @@ -43,11 +43,13 @@ def _patch_unsupported_input_tensor(): """ original_fn = lowering.unsupported_input_tensor - def patched_fn(t: torch.Tensor, node=None): + def patched_fn(t: torch.Tensor, *args, **kwargs): """Bypass meta tensor check.""" if t.is_meta: return False - return original_fn(t, node) + return original_fn( + t, *args, **kwargs + ) # a generic pass-through of the arguments to accommodate torch side change lowering.unsupported_input_tensor = patched_fn try: diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index da4df91f69..aa1b250b3a 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -453,7 +453,8 @@ class AutoTuner: p.name for p in inspect.signature(runner.forward).parameters.values() } - valid_tactics = runner.get_valid_tactics(input_tensors, profile) + valid_tactics = runner.get_valid_tactics(input_tensors, profile, + **kwargs) if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: runner( input_tensors, diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index ba71e4fbfe..098af11fc8 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -531,3 +531,11 @@ def _register_fake(): return router_logits.new_empty( sz, dtype=torch.int32), router_logits.new_empty(sz, dtype=torch.float32) + + @torch.library.register_fake("trtllm::default_moe_routing_op") + def _(router_logits, topk): + num_tokens = router_logits.shape[0] + sz = (num_tokens, topk) + return router_logits.new_empty( + sz, dtype=torch.int32), router_logits.new_empty(sz, + dtype=torch.float32) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index bd946343b0..7d0c73364d 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -81,12 +81,9 @@ class MoERunner(TunableRunner): use_fused_finalize) self.fused_moe_runner = MoERunner.runner_dict[instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - return range(self.fused_moe_runner.get_tactic_num()) + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: + return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"])) def forward( self, @@ -318,11 +315,8 @@ class FP8RowwiseGemmRunner(TunableRunner): self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp8_rowwise_gemm_runner.get_num_configs())) def forward( @@ -403,11 +397,8 @@ class FP4GemmRunner(TunableRunner): output_dtype, int(fp4_gemm_type)) self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp4_gemm_runner.get_num_configs())) def forward( @@ -518,11 +509,8 @@ class FP8BatchedGemmRunner(TunableRunner): return out_tensors - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: mat1, mat2, _, _, _ = inputs @@ -735,11 +723,8 @@ class WeightOnlyQuantGemmRunner(TunableRunner): self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.weight_only_quant_gemm_runner.get_num_configs())) def forward( @@ -813,11 +798,8 @@ class FinegrainedMixedDtypeGemm(TunableRunner): self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[ instance_key] - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: return list( range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs())) diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 2bb780f6ef..bbee1b8102 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -122,11 +122,8 @@ class FP4BlockScaleMoERunner(TunableRunner): self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, self.do_finalize, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = FP4BlockScaleMoEInputs(*inputs) @@ -409,11 +406,8 @@ class FP8BlockScaleMoERunner(TunableRunner): self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = FP8BlockScaleMoEInputs(*inputs) @@ -670,11 +664,8 @@ class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner): self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs) @@ -907,11 +898,8 @@ class E4m3MxE2m1BlockScaleMoERunner(TunableRunner): self.local_expert_offset, self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = E4m3MxE2m1BlockScaleMoEInputs(*inputs) @@ -1123,11 +1111,8 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner): self.local_num_experts, self.routed_scaling_factor, self.routing_method_type, tactic) - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: args = Bf16MxE2m1BlockScaleMoEInputs(*inputs) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 8eb9acfada..c9b9fa979f 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -65,7 +65,7 @@ from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..peft.lora.layer import LoraLayer -from ..speculative import MTPSpecMetadata, SpecMetadata +from ..speculative import SpecMetadata from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights, @@ -230,7 +230,7 @@ class DeepseekV3Attention(MLA): aux_stream: Optional[torch.cuda.Stream] = None, ): config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1 + predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 super().__init__(hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, @@ -750,6 +750,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -765,16 +766,24 @@ class DeepseekV3DecoderLayer(DecoderLayer): **kwargs, ) if isinstance(self.mlp, Deepseekv3MoE): + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False return self.forward_MoE( hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, + spec_metadata=spec_metadata, ) else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MLP_FUSION = False assert isinstance(self.mlp, GatedMLP) return self.forward_mlp( hidden_states=hidden_states, residual=residual, + spec_metadata=spec_metadata, ) def forward_MoE( @@ -782,6 +791,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): @@ -856,6 +866,10 @@ class DeepseekV3DecoderLayer(DecoderLayer): hidden_states, residual = self.moe_allreduce( fc2_output, all_reduce_params=moe_all_reduce_params) else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) @@ -866,6 +880,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): self, hidden_states: torch.Tensor, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if self.fusion_config.PRE_MLP_FUSION: @@ -903,6 +918,10 @@ class DeepseekV3DecoderLayer(DecoderLayer): ), ) else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) @@ -1105,6 +1124,7 @@ class DeepseekV3Model(DecoderModel): hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, + spec_metadata=spec_metadata, ) return hidden_states @@ -1132,7 +1152,8 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, model_config=model_config) self.model_nextn = 0 - if model_config.spec_config is not None: + if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp( + ): model_nextn = model_config.spec_config.num_nextn_predict_layers ckpt_nextn = self.config.num_nextn_predict_layers self.num_hidden_layers = self.config.num_hidden_layers @@ -1167,11 +1188,10 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, input_ids: torch.IntTensor = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[MTPSpecMetadata] = None, + spec_metadata: Optional[SpecMetadata] = None, return_context_logits: bool = False, **kwargs, ) -> torch.Tensor: - attn_metadata.num_generations_per_batch = self.model_nextn + 1 return super().forward(attn_metadata=attn_metadata, input_ids=input_ids, position_ids=position_ids, @@ -1313,7 +1333,9 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): - if len(module._parameters) > 0: + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + else: names = name.split('.') parent_module_name = '.'.join(names[:-1]) if "model.layers" in name and int( diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index e305b82dba..ce8bcc6c8f 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -194,11 +194,16 @@ class Gemma3VLM(PreTrainedModel): "text_config", "vision_config" ], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead." pretrained_config = getattr(model_config.pretrained_config, name) + # ModelOpt currently doesn't quantize the vision part. Without setting quant config to None, + # weight loading fails for vision. + quant_config = model_config.quant_config if name == "text_config" else None + # FlashInfer backend supports custom mask which is needed for bidirectional mask in decoder. preferred_backend = "FLASHINFER" if name == "text_config" else "TRTLLM" sub_model_config: ModelConfig[Gemma3Config] = dataclasses.replace( model_config, pretrained_config=pretrained_config, - attn_backend=preferred_backend) + attn_backend=preferred_backend, + quant_config=quant_config) # Make sure some fields that are not explicitly included in the sub config, but present # in the top-level config, are replicated. if (hasattr(sub_model_config.pretrained_config, "torch_dtype") diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 41f870f890..e548d09a08 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -221,7 +221,9 @@ class NemotronHModel(DecoderModel): ) if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests: - self.mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests) + self.mamba_metadata = Mamba2Metadata( + attn_metadata.max_num_requests, + chunk_size=self.model_config.pretrained_config.chunk_size) self.mamba_metadata.prepare(attn_metadata) if inputs_embeds is None: diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index ee0263eb5e..bc449e1da5 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -611,23 +611,21 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): @staticmethod def lora_config(model_dir: str): _lora_config = LoraConfig( - lora_dir=[ - f"{model_dir}/vision-lora", - f"{model_dir}/speech-lora", - ], lora_target_modules=[ "attn_qkv", "attn_dense", - "mlp_h_to_4h", + "mlp_gate_up", "mlp_4h_to_h", ], trtllm_modules_to_hf_modules={ "attn_qkv": "qkv_proj", "attn_dense": "o_proj", - "mlp_h_to_4h": "gate_up_proj", + "mlp_gate_up": "gate_up_proj", "mlp_4h_to_h": "down_proj", }, max_lora_rank=320, # Max rank for Phi4MM. + swap_gate_up_proj_lora_b_weight= + False, # Disable swap gate_up_proj.lora_B.weight for Phi4MM. ) return _lora_config diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index f82c3b4de0..56a489c963 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -155,10 +155,12 @@ class Eagle3DraftModel(DecoderModel): else: self.hidden_size_in = config.hidden_size - self.fc = Linear(self.hidden_size_in * 3, - config.hidden_size, - bias=getattr(config, "bias", False), - dtype=config.torch_dtype) + if self.spec_config.num_capture_layers > 1: + self.fc = Linear(self.hidden_size_in * + self.spec_config.num_capture_layers, + config.hidden_size, + bias=getattr(config, "bias", False), + dtype=config.torch_dtype) self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx) diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 34c2179593..635091c7ad 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -183,18 +183,28 @@ class BaseMoeRoutingMethod(nn.Module): class DefaultMoeRoutingMethod(BaseMoeRoutingMethod): - def __init__(self, top_k: int): + def __init__(self, top_k: int, force_enable_pytorch_op: bool = False): super().__init__() self.top_k = top_k + self.force_enable_pytorch_op = force_enable_pytorch_op - def apply(self, - router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def apply_pytorch( + self, router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): topk_values, topk_indices = torch.topk(torch.nn.functional.softmax( router_logits.float(), dim=-1), k=self.top_k, dim=-1) return topk_indices.to(torch.int32), topk_values + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + num_experts = router_logits.shape[-1] + if self.force_enable_pytorch_op or num_experts > 128 or self.top_k > 8: + return self.apply_pytorch(router_logits) + else: + return torch.ops.trtllm.default_moe_routing_op( + router_logits, self.top_k) + @property def routing_method_type(self): return RoutingMethodType.Default diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py index 445c288e6f..d421cc9209 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py @@ -13,15 +13,83 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +from typing import Tuple + import torch from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata +def cu_seqlens_to_chunk_indices_offsets( + cu_seqlens: torch.Tensor, + chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths, shape (num_seqs + 1,). The first element should be 0. Each entry represents the starting index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk (number of tokens per chunk). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets indicating the starting index of each logical chunk within its physical chunk. + + This function computes the chunk indices and offsets for the given cu_seqlens and chunk_size. + Both are tensors of integers with length N, where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the logical chunk in the physical chunk. + + Example: + cu_seqlens = [0, 5, 10] + chunk_size = 8 + -> chunk_indices = [0, 1, 0] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical chunk size is 8 tokens. + We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk and contains the remaining 2 tokens from the second sequence + """ + + total_seqlens = cu_seqlens[-1] + cu_seqlens = cu_seqlens[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=cu_seqlens.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=cu_seqlens.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0) + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + class Mamba2Metadata: - def __init__(self, max_batch_size: int): + def __init__(self, max_batch_size: int, chunk_size: int): self.max_batch_size = max_batch_size + self.chunk_size = chunk_size # cumulative sequence lengths for prefill requests [batch_size+1] self.cu_seqlens = torch.zeros(max_batch_size + 1, @@ -31,9 +99,18 @@ class Mamba2Metadata: # sequence index for prefill requests [num_prefill_tokens] - specifies which request each token belongs to self.seq_idx: torch.Tensor = None + # helper tensors for chunked prefill + self.has_initial_states = torch.zeros(max_batch_size, + dtype=torch.bool, + device="cuda") + self.use_initial_states = False + self.chunk_indices: torch.Tensor = None + self.chunk_offsets: torch.Tensor = None + def prepare(self, attn_metadata: AttentionMetadata): num_contexts = attn_metadata.num_contexts context_lens = attn_metadata.seq_lens_cuda[:num_contexts] + num_ctx_tokens = attn_metadata.num_ctx_tokens if num_contexts > 0: torch.cumsum(context_lens, dim=0, @@ -44,4 +121,17 @@ class Mamba2Metadata: dtype=torch.int, device=self.cu_seqlens.device), repeats=context_lens, - output_size=self.cu_seqlens[num_contexts]).unsqueeze(0) + output_size=num_ctx_tokens).unsqueeze(0) + + num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq + self.has_initial_states[:num_contexts] = torch.tensor( + num_cached_tokens_per_seq[:num_contexts]) > 0 + # precomputed bool to avoid host<->device syncs during forward pass + self.use_initial_states = torch.any( + self.has_initial_states[:num_contexts]).item() + if self.use_initial_states: + self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + self.cu_seqlens[:num_contexts + 1], self.chunk_size) + else: + self.chunk_indices = None + self.chunk_offsets = None diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 6ea096bb6a..d5a3e3996a 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -191,12 +191,15 @@ class Mamba2Mixer(nn.Module): cu_seqlens = mamba_metadata.cu_seqlens[:num_prefills + 1] seq_idx = mamba_metadata.seq_idx + has_initial_states = mamba_metadata.has_initial_states[: + num_prefills] xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1), self.conv1d.weight, self.conv1d.bias, activation="silu", conv_states=conv_states, + has_initial_state=has_initial_states, query_start_loc=cu_seqlens, cache_indices=state_indices_p).transpose( 0, 1) @@ -216,6 +219,12 @@ class Mamba2Mixer(nn.Module): "b l (h p) -> b l h p", h=self.tp_nheads) + initial_states = None + if mamba_metadata.use_initial_states: + initial_states = torch.where( + has_initial_states[:, None, None, None], + ssm_states[state_indices_p], 0) + y, current_ssm_states = mamba_chunk_scan_combined( x_p, dt_p, @@ -226,7 +235,9 @@ class Mamba2Mixer(nn.Module): D=self.D, z=z_p, dt_bias=self.dt_bias, - initial_states=None, + initial_states=initial_states, + chunk_indices=mamba_metadata.chunk_indices, + chunk_offsets=mamba_metadata.chunk_offsets, dt_softplus=self.delta_softplus, cu_seqlens=cu_seqlens, seq_idx=seq_idx, diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py index 58615ab923..23b55d8811 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py @@ -314,11 +314,12 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. the same for all blocks) dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + - (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, - mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) - and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and (c_off < chunk_size)), other=0.0).to(tl.float32) if HAS_SEQ_IDX: diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index 0a6f18bb63..8edbe902bd 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -110,21 +110,24 @@ def _mamba_chunk_scan_combined_fwd( # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=(rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None), seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=mamba_ssm_cache_dtype or C.dtype, - is_cont_batched=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = [ rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py b/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py index e1c4b61eaf..f751d4cd5f 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py @@ -41,6 +41,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -61,6 +63,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -76,7 +79,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += (pid_b * stride_final_states_batch + pid_h * stride_final_states_head) @@ -105,35 +109,63 @@ def _state_passing_fwd_kernel( other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. # - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += (seq_idx_chunk_end - + seq_idx_chunk_start) + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -146,28 +178,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used are for each example of the batch. assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -183,13 +223,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -201,9 +243,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *(( initial_states.stride(0), initial_states.stride(1), diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 2f0753ed31..20e3aaaa09 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -514,7 +514,8 @@ def create_py_executor_instance( resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager model_engine.set_lora_model_config( lora_config.lora_target_modules, - lora_config.trtllm_modules_to_hf_modules) + lora_config.trtllm_modules_to_hf_modules, + lora_config.swap_gate_up_proj_lora_b_weight) max_num_sequences = executor_config.max_batch_size * mapping.pp_size diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 96c5957ef9..8cfccb020a 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType from ..distributed import Distributed from .llm_request import (ExecutorRequest, LlmRequest, executor_request_to_llm_request) -from .sampler import Sampler, TorchSampler SHUTDOWN_REQUEST_ID = -1 @@ -707,21 +706,19 @@ class ExecutorRequestQueue: def set_exclude_last_generation_logits(self, disable_overlap_scheduler: bool, - sampler: Sampler) -> None: + pp_size: int) -> None: # When overlap scheduler is enabled then when starting to handle a new prompt, # sample_async is called twice before the first call to update_requests: # - 1st time as a context request that handles on the 1st generated token # - 2nd time as a generation request that handles on the 2nd generated token. # and only after these two calls the sampler's update_request method is called. # So in a sampler that works by the expected flow of handling the logits in - # sample_async (TorchSampler is an anomaly that instead does that on - # update_requests), every update_request doesn't handle the newest token, but one + # sample_async, every update_request doesn't handle the newest token, but one # before it. Since all these calls work on the same request object, then its # logits storage contains the logits of both the token update_requests should work # on, and also its next token. Thus, excluding the last generation logits from any - # getter is required, when not using TorchSampler. - self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance( - sampler, TorchSampler) + # getter is required. + self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1 def _should_exclude_last_generation_logits(self) -> bool: return self.should_exclude_last_generation_logits diff --git a/tensorrt_llm/_torch/pyexecutor/handle_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_logits.py index 81986df593..b3d7ced6a5 100644 --- a/tensorrt_llm/_torch/pyexecutor/handle_logits.py +++ b/tensorrt_llm/_torch/pyexecutor/handle_logits.py @@ -1,3 +1,4 @@ +from itertools import chain from typing import List import torch @@ -16,9 +17,9 @@ class HandleLogits: context_requests: List[LlmRequest], generation_requests: List[LlmRequest], logits: torch.Tensor, - num_context_logits_prefix_sum: List[int], - max_num_sequences: int, beam_width: int, + num_context_logits_prefix_sum: list[int], + is_generation_model: bool, ): """Handles context and generation logits for a batch of requests. @@ -26,10 +27,24 @@ class HandleLogits: context_requests: List of context requests to process generation_requests: List of generation requests to process logits: Input logits tensor - num_context_logits_prefix_sum: Prefix sum of context logits for each request - max_num_sequences: Maximum number of sequences to process beam_width: Beam width for the generation requests + num_context_logits_prefix_sum: Prefix sum of the logits + is_generation_model: Bool containing whether the model is generation or not """ + if not any(r.py_return_context_logits or r.py_return_generation_logits + for r in chain(context_requests, generation_requests)): + return + + if not is_generation_model: + for llm_req, logits_temp in zip(context_requests, logits): + if logits_temp.ndim == 1: + # For BERT: Add axis to be compatible with LogitsStorage + # (LogitsStorage will interpret this dim as the prompt_len which + # is not relevant for outputting logits of encoder only model). + logits_temp = logits_temp.unsqueeze(0) + llm_req.py_result.append_context_logits(logits_temp) + return + # Copy logits into decoderBuffers.logits for batch_index, llm_req in enumerate(context_requests): logits_begin = num_context_logits_prefix_sum[batch_index] diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index d9f180c0fc..1b3fbfbfc4 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -468,13 +468,16 @@ class PyTorchModelEngine(ModelEngine): def runtime_draft_len(self): return self.max_draft_len if self.enable_spec_decode else 0 - def set_lora_model_config(self, lora_target_modules: list[str], - trtllm_modules_to_hf_modules: dict[str, str]): + def set_lora_model_config(self, + lora_target_modules: list[str], + trtllm_modules_to_hf_modules: dict[str, str], + swap_gate_up_proj_lora_b_weight: bool = True): self.lora_model_config = LoraModelConfig( lora_target_modules=lora_target_modules, trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules, hidden_size=self.model.config.hidden_size, - dtype=torch_dtype_to_str(self.model.config.torch_dtype)) + dtype=torch_dtype_to_str(self.model.config.torch_dtype), + swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight) @property def use_mrope(self): diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index a40b9b9045..453434d9d6 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -39,6 +39,7 @@ from ..models.modeling_utils import DecoderModelForCausalLM from ..speculative.drafter import Drafter from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder +from .handle_logits import HandleLogits from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, LlmResponse) @@ -244,7 +245,7 @@ class PyExecutor: is_disaggregated=kv_cache_transceiver is not None, ) self.executor_request_queue.set_exclude_last_generation_logits( - self.disable_overlap_scheduler, self.sampler) + self.disable_overlap_scheduler, self.dist.pp_size) self.stats_lock = threading.Lock() self.stats = [] @@ -681,24 +682,6 @@ class PyExecutor: self.response_cv.notify_all() self.shutdown_event.set() - def _need_return_logits(self, scheduled_requests: ScheduledRequests): - for req in scheduled_requests.context_requests: - if req.py_return_context_logits: - return True - for req in scheduled_requests.generation_requests: - if req.py_return_generation_logits: - return True - return False - - def _need_return_log_probs(self, scheduled_requests: ScheduledRequests): - for req in scheduled_requests.context_requests: - if req.py_return_log_probs: - return True - for req in scheduled_requests.generation_requests: - if req.py_return_log_probs: - return True - return False - def _executor_loop_pp(self): logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}") torch.cuda.set_device(self.device_id) @@ -790,10 +773,6 @@ class PyExecutor: else: with torch.cuda.nvtx.range("_forward_step_last_pp"): batch_outputs = self._forward_step(scheduled_batch) - logits_host = None - if self._need_return_logits(scheduled_batch): - logits_host = batch_outputs["logits"].to( - "cpu", non_blocking=True) if self.kv_cache_transceiver and self.guided_decoder: self.guided_decoder.init_disagg_gen_requests( scheduled_batch) @@ -802,7 +781,6 @@ class PyExecutor: sample_state = self._sample_async( scheduled_batch, batch_outputs) - sample_state.host.logits = logits_host self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -832,18 +810,10 @@ class PyExecutor: torch.cuda.nvtx.range_push( "_handle_new_tokens_inter_pp") # Receive tokens from previous pp rank (w.r.t model forward direction) - ( - logits, - sample_state.host, - ) = self.dist.recv_object( + sample_state.host = self.dist.recv_object( src=self.dist.prev_pp_rank, tag=prev_microbatch_id, ) - if logits is not None: - logits_host = torch.from_numpy(logits) - sample_state.host.logits = logits_host - sample_state.device.logits = logits_host.to( - self.device_id) else: torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") sample_state.sampler_event.synchronize() @@ -853,18 +823,9 @@ class PyExecutor: if not self.dist.is_second_last_pp_rank: if self.send_handles[prev_microbatch_id] is not None: self.send_handles[prev_microbatch_id].wait() - needs_logits = ( - self._need_return_logits(scheduled_batch) - or (self._need_return_log_probs(scheduled_batch) - and sample_state.host.log_probs is not None)) - serialized_logits = sample_state.host.logits.numpy( - ) if needs_logits else None self.send_handles[ prev_microbatch_id] = self.dist.isend_object( - ( - serialized_logits, - sample_state.host, - ), + sample_state.host, dest=self.dist.next_pp_rank, tag=prev_microbatch_id) torch.cuda.nvtx.range_pop() @@ -884,6 +845,40 @@ class PyExecutor: previous_batch.scheduled_ctx_reqs) self._handle_canceled_requests() + + # If logits were requested last PP rank has to send to first PP rank (who sends responses) the + # logits of the requests that have finished. + # NOTE: If the rank processing the logits ever becomes the same as + # the rank sending the responses, this code can be removed. + finished_reqs = [ + r for r in previous_batch.sample_state. + scheduled_requests.all_requests() + if r.state == LlmRequestState.GENERATION_COMPLETE + and (r.py_return_context_logits + or r.py_return_generation_logits) + ] + if self.dist.is_first_pp_rank and len(finished_reqs): + finished_reqs_py_results = [ + r.py_result for r in finished_reqs + ] + finished_reqs_py_results = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=prev_microbatch_id, + ) + for req, py_result in zip(finished_reqs, + finished_reqs_py_results): + req.py_result = py_result + + elif self.dist.is_last_pp_rank and len(finished_reqs): + if self.send_handles[ + prev_microbatch_id] is not None: + self.send_handles[prev_microbatch_id].wait() + self.send_handles[ + prev_microbatch_id] = self.dist.isend_object( + [r.py_result for r in finished_reqs], + dest=self.dist.next_pp_rank, + tag=prev_microbatch_id) + finished_requests = self._handle_responses() previous_scheduled_batch = previous_batch.sample_state.scheduled_requests self.resource_manager.update_resources( @@ -1538,7 +1533,22 @@ class PyExecutor: batch_outputs) -> SampleState | None: try: if batch_outputs is not None: - return self.sampler.sample_async(scheduled_batch, batch_outputs) + num_context_logits_prefix_sum = [0] + prefix_sum = 0 + for request in scheduled_batch.context_requests: + prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 + num_context_logits_prefix_sum.append(prefix_sum) + + HandleLogits()(scheduled_batch.context_requests, + scheduled_batch.generation_requests, + batch_outputs["logits"], + self.sampler.beam_width( + scheduled_batch.all_requests()), + num_context_logits_prefix_sum, + self.sampler.is_generation_model()) + + return self.sampler.sample_async(scheduled_batch, batch_outputs, + num_context_logits_prefix_sum) except Exception as e: traceback.print_exc() error_msg = str(e) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index be898a54a7..715f5c7b47 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -9,7 +9,7 @@ import torch import tensorrt_llm from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType -from tensorrt_llm._utils import get_sm_family, get_sm_version +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig from tensorrt_llm.logger import logger @@ -23,7 +23,7 @@ from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter, get_spec_resource_manager) from ._util import (KvCacheCreator, _adjust_torch_mem_fraction, create_py_executor_instance, instantiate_sampler, is_mla) -from .config import PyTorchConfig +from .config import LoadFormat, PyTorchConfig from .config_utils import is_mla from .guided_decoder import GuidedDecoder from .model_engine import PyTorchModelEngine @@ -252,13 +252,16 @@ def create_py_executor( with mem_monitor.observe_creation_stage( _ExecutorCreationStage.MODEL_ENGINE_DRAFT): draft_spec_config = copy.copy(spec_config) + draft_pytorch_backend_config = copy.copy(pytorch_backend_config) + if spec_config.load_format == "dummy": + draft_pytorch_backend_config.load_format = LoadFormat.DUMMY # The draft model won't have any draft tokens attached to # generation requests when we invoke it autoregressively draft_spec_config.max_draft_len = 0 draft_model_engine = PyTorchModelEngine( model_path=spec_config.speculative_model_dir, - pytorch_backend_config=pytorch_backend_config, + pytorch_backend_config=draft_pytorch_backend_config, batch_size=executor_config.max_batch_size, max_beam_width=executor_config.max_beam_width, max_num_tokens=executor_config.max_num_tokens, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 9a5b42166d..4066b45cf8 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1040,7 +1040,8 @@ class PeftCacheManager(BaseResourceManager): self._lora_model_config = LoraModelConfig( lora_config.lora_target_modules, lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size, - binding_to_str_dtype(model_config.data_type)) + binding_to_str_dtype(model_config.data_type), + lora_config.swap_gate_up_proj_lora_b_weight) self._lora_manager = LoraManager() def add_request_peft(self, request: LlmRequest): diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 919b99be2d..e6d19a9df4 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -5,7 +5,6 @@ from typing import List, Literal, Optional import torch -from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \ MakeDecodingBatchInputOutput from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding @@ -30,7 +29,6 @@ from .scheduler import ScheduledRequests @dataclass(kw_only=True) class SampleStateTensors: new_tokens: torch.Tensor - logits: torch.Tensor | None = None log_probs: torch.Tensor | None = None def values(self): @@ -58,14 +56,24 @@ class Sampler(ABC): return None @abstractmethod - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: + def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleState: raise NotImplementedError @abstractmethod def update_requests(self, state: SampleState) -> None: raise NotImplementedError + @staticmethod + def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: + for req in scheduled_requests: + return req.sampling_config.beam_width + return 0 + + @abstractmethod + def is_generation_model(self) -> bool: + raise NotImplementedError + class EarlyStopSampler(Sampler): """ @@ -73,10 +81,9 @@ class EarlyStopSampler(Sampler): such as encoder-only model (e.g., BERT) or reward models that only need context phase. """ - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: - host = SampleStateTensors(logits=model_outputs['logits'], - new_tokens=torch.empty(0)) + def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleState: + host = SampleStateTensors(new_tokens=torch.empty(0)) return SampleState(scheduled_requests=scheduled_requests, host=host) def update_requests(self, state: SampleState) -> None: @@ -87,14 +94,9 @@ class EarlyStopSampler(Sampler): request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) - if request.py_return_context_logits: - logits = state.host.logits[idx] - if logits.ndim == 1: - # For BERT: Add axis to be compatible with LogitsStorage - # (LogitsStorage will interpret this dim as the prompt_len which - # is not relevant for outputting logits of encoder only model). - logits = logits.unsqueeze(0) - request.py_result.append_context_logits(logits) + + def is_generation_model(self) -> bool: + return False @dataclass(kw_only=True) @@ -117,8 +119,10 @@ class EarlyStopWithMMResult(Sampler): Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings) """ - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleStateWithMMResult: + def sample_async( + self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int] + ) -> SampleStateWithMMResult: # from model_outputs to MultimodalResult data = MultimodalResult(mm_embeddings=model_outputs['mm_embeddings']) return SampleStateWithMMResult(scheduled_requests=scheduled_requests, @@ -141,6 +145,9 @@ class EarlyStopWithMMResult(Sampler): request.py_result.append_mm_embeddings(mm_embedding) + def is_generation_model(self) -> bool: + return False + def top_k_sampling_batch(logits, top_k=50, @@ -352,6 +359,9 @@ class TorchSampler(Sampler): BEAM = 0 MAX_BEAM_WIDTH = BEAM + 1 + def is_generation_model(self) -> bool: + return True + @dataclass(frozen=True, kw_only=True) class Store: new_tokens: torch.Tensor @@ -445,13 +455,9 @@ class TorchSampler(Sampler): return False - def handle_logits(self, request: LlmRequest, state: SampleState, *, - beam: int, count: int): + def handle_logprobs(self, request: LlmRequest, state: SampleState, *, + beam: int, count: int): current_slice = slice(0, count), request.py_seq_slot, beam - if request.py_return_generation_logits: - assert state.host.logits is not None - current_logits = state.host.logits[current_slice] - request.py_result.append_generation_logits(current_logits) if request.py_return_log_probs: assert state.host.log_probs is not None log_probs = state.host.log_probs[request.py_seq_slot][beam][:count] @@ -546,7 +552,7 @@ class TorchSampler(Sampler): continue new_token = add_token(req, new_tokens, beam=self.BEAM) self._handle_stop_criteria(req, new_token) - self.handle_logits(req, state, beam=self.BEAM, count=1) + self.handle_logprobs(req, state, beam=self.BEAM, count=1) req.py_decoding_iter += 1 for req in state.scheduled_requests.generation_requests: @@ -558,37 +564,28 @@ class TorchSampler(Sampler): req.py_num_accepted_draft_tokens = num_accepted req.py_rewind_len = req.py_draft_pages_allocated - num_accepted processed += num_accepted - self.handle_logits(req, state, beam=self.BEAM, count=processed) + self.handle_logprobs(req, state, beam=self.BEAM, count=processed) req.py_decoding_iter += 1 - def log_probs_host(self, requests: Iterable[LlmRequest]): + def log_probs_host(self, scheduled_requests: ScheduledRequests): """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" - if any(req.py_return_log_probs for req in requests): + if any(req.py_return_log_probs + for req in scheduled_requests.all_requests()): return torch.empty( (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens), device="cpu", pin_memory=True) return None - def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int): - if any(req.py_return_generation_logits for req in requests): - return torch.empty((self.max_tokens, self.max_num_sequences, - self.MAX_BEAM_WIDTH, vocab_size), - device="cpu", - pin_memory=True) - return None - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor]) -> SampleState: - requests = scheduled_requests.all_requests() + model_outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int]) -> SampleState: new_tokens = self.store.new_tokens - vocab_size = model_outputs["logits"].shape[-1] - log_probs_host = self.log_probs_host(requests) - gen_logits_host = self.gen_logits_host(requests, vocab_size) - self._process_requests(requests, + log_probs_host = self.log_probs_host(scheduled_requests) + self._process_requests(scheduled_requests, model_outputs, new_tokens, - gen_logits_host=gen_logits_host, + num_context_logits_prefix_sum, log_probs_host=log_probs_host) new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) sampler_event = torch.cuda.Event() @@ -596,8 +593,7 @@ class TorchSampler(Sampler): return SampleState(scheduled_requests=scheduled_requests, device=SampleStateTensors(new_tokens=new_tokens), host=SampleStateTensors(new_tokens=new_tokens_host, - log_probs=log_probs_host, - logits=gen_logits_host), + log_probs=log_probs_host), sampler_event=sampler_event) @staticmethod @@ -659,19 +655,37 @@ class TorchSampler(Sampler): return logits def _process_requests(self, - requests: list[LlmRequest], + scheduled_requests: ScheduledRequests, model_outputs: dict[str, torch.Tensor], new_tokens: torch.Tensor, + num_context_logits_prefix_sum: list[int], *, - gen_logits_host: torch.Tensor | None = None, log_probs_host: torch.Tensor | None = None): beam_width = self.MAX_BEAM_WIDTH beam = self.BEAM - raw_logits = model_outputs["logits"] + + # raw_logits should contain only the logits from the gen requests. + # If return context logits is requested, fetch only the logits from gen requests. + if any(r.py_return_context_logits + for r in scheduled_requests.context_requests): + gen_logits_indices = [] + total_context_logits = num_context_logits_prefix_sum[-1] + for i in range(len(scheduled_requests.context_requests)): + gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] - + 1) + gen_logits_indices.extend( + range( + total_context_logits, total_context_logits + + len(scheduled_requests.generation_requests))) + raw_logits = model_outputs["logits"][gen_logits_indices] + else: + raw_logits = model_outputs["logits"] + + requests = scheduled_requests.all_requests() num_steps = [1 + get_draft_token_length(req) for req in requests] sum_steps = sum(num_steps) no_draft_tokens = len(requests) == sum_steps - fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None + fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests]) seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) @@ -727,8 +741,6 @@ class TorchSampler(Sampler): new_tokens[current_slice] = next_tokens if request.py_draft_logits is not None: request.py_target_probs = softmax.clone() - if gen_logits_host is not None: - gen_logits_host[current_slice].copy_(logits, non_blocking=True) if log_probs_host is not None: assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" token_probs = torch.gather( @@ -769,6 +781,9 @@ class TRTLLMSampler(Sampler): MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding SampleState = SampleStateTRTLLM + def is_generation_model(self) -> bool: + return True + def __init__( self, executor_config: ExecutorConfig, @@ -864,7 +879,6 @@ class TRTLLMSampler(Sampler): speculative_decoding_fast_logits=False, is_leader_in_orch_mode=False, is_normalize_log_probs=False) - self.algs.handle_logits = HandleLogits() self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput( ) @@ -898,13 +912,6 @@ class TRTLLMSampler(Sampler): slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32) self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) - @staticmethod - @torch.inference_mode() - def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: - for req in scheduled_requests: - return req.sampling_config.beam_width - return 0 - def get_cache_indirection(self) -> torch.Tensor | None: return self.store["decoder_state"].cache_indirection_output @@ -920,8 +927,9 @@ class TRTLLMSampler(Sampler): @torch.inference_mode() @nvtx_range("sample_async") - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleStateTRTLLM: + def sample_async( + self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleStateTRTLLM: batch_size = scheduled_requests.batch_size beam_width = self.beam_width(scheduled_requests.all_requests()) @@ -934,29 +942,10 @@ class TRTLLMSampler(Sampler): self.setup_sampler_step(scheduled_requests) - num_context_logits_prefix_sum = [0] - prefix_sum = 0 - for request in scheduled_requests.context_requests: - prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 - num_context_logits_prefix_sum.append(prefix_sum) - - if any(r.py_return_context_logits or r.py_return_generation_logits - for r in scheduled_requests.all_requests()): - self.algs.handle_logits(scheduled_requests.context_requests, - scheduled_requests.generation_requests, - model_outputs["logits"], - num_context_logits_prefix_sum, - self.max_num_sequences, beam_width) - # For beam search, cache indirection needs to be updated if beam_width > 1: self._update_cache_indirection_buffer(scheduled_requests) - # TODO: Enable this back once nanobind is merged and/or llm request is a pure python object - # decoding_input = self.algs.make_decoding_batch_input_output( - # scheduled_requests, model_outputs["logits"], beam_width, - # num_context_logits_prefix_sum) - self.store["decoding_input"][ self.micro_batch_idx] = make_decoding_batch_input( scheduled_requests.context_requests, diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 417becf12f..2d4225641b 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Tuple +from typing import List, Optional, Set import torch from torch import nn @@ -35,9 +35,10 @@ class Eagle3ResourceManager(BaseResourceManager): # empty hidden states tensor max_num_tokens = min(max_num_tokens, max_num_requests * self.max_seq_len) - self.hidden_states = torch.empty((max_num_tokens, self.hidden_size * 3), - dtype=self.dtype, - device='cuda') + self.hidden_states = torch.empty( + (max_num_tokens, self.hidden_size * config.num_capture_layers), + dtype=self.dtype, + device='cuda') # sequence length, only used for metadata preparation self.seq_lens = {i: 0 for i in range(max_num_requests)} # start indices of each slot @@ -79,8 +80,7 @@ class Eagle3ResourceManager(BaseResourceManager): @dataclass class Eagle3SpecMetadata(SpecMetadata): hidden_states: List[torch.Tensor] = field(default_factory=list) - num_capture_layers: int = 3 - layers_to_capture: Tuple[int, ...] = field(init=False) + layers_to_capture: Optional[Set[int]] = None target_model_embed_tokens: Optional[torch.nn.Module] = None hidden_size: int = 0 max_num_tokens: int = 0 @@ -90,14 +90,19 @@ class Eagle3SpecMetadata(SpecMetadata): eagle3_resource_manager: Optional[Eagle3ResourceManager] = None def __post_init__(self): - if self.num_layers == 1: - self.layers_to_capture = (0, ) - else: - if self.num_layers <= 5: - raise ValueError("Not enough hidden layers for EAGLE") + if self.layers_to_capture is None: + if self.num_layers == 1: + self.layers_to_capture = (self.num_layers - 1, ) + else: + if self.num_layers <= 5: + raise ValueError( + "Not enough hidden layers for default EAGLE3 capture") - self.layers_to_capture = (1, self.num_layers // 2 - 1, - self.num_layers - 4) + self.layers_to_capture = (1, self.num_layers // 2 - 1, + self.num_layers - 4) + else: + self.layers_to_capture = sorted(list(self.layers_to_capture)) + self.num_capture_layers = len(self.layers_to_capture) # Initialize to 0 to avoid reading uninitialized memory during warmup self.hidden_states_read_indices = torch.zeros([self.max_num_tokens], @@ -186,7 +191,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): # The hidden states hidden_states: Optional[torch.Tensor] = None # The layers to be captured - layers_to_capture: Tuple[int, ...] = field(init=False) + layers_to_capture: Optional[Set[int]] = None # The hidden size of the hidden states hidden_size: int = 0 # The max number of tokens @@ -197,14 +202,19 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): batch_indices_cuda: Optional[torch.Tensor] = None def __post_init__(self): - if self.num_layers == 1: - self.layers_to_capture = (1, ) - else: - if self.num_layers <= 5: - raise ValueError("Not enough hidden layers for EAGLE") + if self.layers_to_capture is None: + if self.num_layers == 1: + self.layers_to_capture = (self.num_layers - 1, ) + else: + if self.num_layers <= 5: + raise ValueError( + "Not enough hidden layers for default EAGLE3 capture") - self.layers_to_capture = (1, self.num_layers // 2 - 1, - self.num_layers - 4) + self.layers_to_capture = (1, self.num_layers // 2 - 1, + self.num_layers - 4) + else: + self.layers_to_capture = sorted(list(self.layers_to_capture)) + self.num_capture_layers = len(self.layers_to_capture) self.hidden_states = torch.empty( (self.max_num_tokens, self.hidden_size * len(self.layers_to_capture)), diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index f7cdd92a56..1d306b9029 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -185,6 +185,13 @@ class SpecMetadata: cuda_graph_metadata.__post_init__() return cuda_graph_metadata + def is_layer_capture(self, layer_id: int): + """ + Whether the layer should be captured (eg for Eagle3). + By default, does nothing. + """ + return False + def maybe_capture_hidden_states(self, layer_id: int, hidden_states: torch.Tensor, residual: torch.Tensor) -> None: diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 7f11142c3f..5d54f2f3be 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -9,6 +9,7 @@ from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger from ..pyexecutor.guided_decoder import GuidedDecoder +from ..pyexecutor.handle_logits import HandleLogits from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState, get_draft_token_length) from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager @@ -266,7 +267,21 @@ class ModelDrafter(Drafter): """Sample tokens from draft model outputs.""" try: if self.sampler is not None: - return self.sampler.sample_async(draft_batch, outputs) + num_context_logits_prefix_sum = [0] + prefix_sum = 0 + for request in draft_batch.context_requests: + prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 + num_context_logits_prefix_sum.append(prefix_sum) + + HandleLogits()( + draft_batch.context_requests, + draft_batch.generation_requests, outputs["logits"], + self.sampler.beam_width(draft_batch.all_requests()), + num_context_logits_prefix_sum, + self.sampler.is_generation_model()) + + return self.sampler.sample_async(draft_batch, outputs, + num_context_logits_prefix_sum) return None except Exception as e: logger.error(f"Error in sampling: {str(e)}") diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 2658ce539b..b31512df91 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -268,8 +268,10 @@ class MTPSampler(TorchSampler): req.py_rewind_len = self.draft_len - (num_new_tokens - 1) self._request_common_handling(req, next_draft_tokens_list) - def sample_async(self, scheduled_requests: ScheduledRequests, - outputs: dict[str, torch.Tensor]) -> SampleStateMTP: + def sample_async( + self, scheduled_requests: ScheduledRequests, + outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int]) -> SampleStateMTP: # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1 # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index c4a4ccf7e3..16fef4862b 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -38,6 +38,7 @@ def get_spec_metadata(spec_config, dtype=model_config.torch_dtype, is_draft_model=is_draft_model, eagle3_resource_manager=spec_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, ) if spec_config.spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelSpecMetadata( @@ -47,6 +48,7 @@ def get_spec_metadata(spec_config, num_layers=model_config.num_hidden_layers, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, + layers_to_capture=spec_config.eagle3_layers_to_capture, ) if spec_config.spec_dec_mode.is_draft_target() or \ spec_config.spec_dec_mode.is_ngram() or \ diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index acf7f60bcb..fd76466cd5 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -273,6 +273,22 @@ class ReportUtility: }, } + # Retrieve KV cache information. + kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig()) + if isinstance(kv_cache_config, KvCacheConfig): + kv_cache_dtype = kv_cache_config.dtype + kv_cache_mem_percent = kv_cache_config.free_gpu_memory_fraction + elif isinstance(kv_cache_config, dict): + kv_cache_dtype = kv_cache_config.get("dtype", "auto") + kv_cache_mem_percent = kv_cache_config.get( + "free_gpu_memory_fraction") + else: + raise ValueError( + f"Invalid kv_cache_config type: {type(kv_cache_config)}.") + + kv_cache_mem_percent = f"{kv_cache_mem_percent * 100.0:.2f}%" \ + if kv_cache_mem_percent is not None else "None" + # Engine/Backend details if self.rt_cfg.backend not in ('pytorch', '_autodeploy'): config_path = self.rt_cfg.engine_dir / "config.json" @@ -302,15 +318,6 @@ class ReportUtility: model = self.rt_cfg.model_path or self.rt_cfg.model model_config = ModelConfig.from_pretrained(model, trust_remote_code=True) - kv_cache_config = self.kwargs.get("kv_cache_config", - KvCacheConfig()) - if isinstance(kv_cache_config, KvCacheConfig): - kv_cache_dtype = kv_cache_config.dtype - elif isinstance(kv_cache_config, dict): - kv_cache_dtype = kv_cache_config.get("dtype", "auto") - else: - raise ValueError( - f"Invalid kv_cache_config type: {type(kv_cache_config)}.") validate_and_set_kv_cache_quant(model_config, kv_cache_dtype) @@ -336,8 +343,7 @@ class ReportUtility: "max_batch_size": self.rt_cfg.settings_config.max_batch_size, "max_num_tokens": self.rt_cfg.settings_config.max_num_tokens, "scheduling_policy": self.rt_cfg.settings_config.scheduler_policy, - "kv_cache_percentage": - self.rt_cfg.settings_config.kv_cache_percent * 100.0, + "kv_cache_percentage": kv_cache_mem_percent, "issue_rate": self.convert_rate_to_s(self.statistics.issue_rate_ns) } @@ -526,7 +532,7 @@ class ReportUtility: f"Max Runtime Batch Size: {world_info['max_batch_size']}\n" f"Max Runtime Tokens: {world_info['max_num_tokens']}\n" f"Scheduling Policy: {world_info['scheduling_policy']}\n" - f"KV Memory Percentage: {world_info['kv_cache_percentage']:.2f}%\n" + f"KV Memory Percentage: {world_info['kv_cache_percentage']}\n" f"Issue Rate (req/sec): {world_info['issue_rate']:.4E}\n" f"\n") diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 07eb13d796..c1013eb3c5 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -14,6 +14,7 @@ from torch.cuda import device_count from tensorrt_llm import LLM as PyTorchLLM from tensorrt_llm import MultimodalEncoder from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM from tensorrt_llm._utils import mpi_rank from tensorrt_llm.executor.utils import LlmLauncherEnvs from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy, @@ -109,7 +110,7 @@ def get_llm_args(model: str, capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, dynamic_batch_config=dynamic_batch_config, ) - + backend = backend if backend in ["pytorch", "_autodeploy"] else None llm_args = { "model": model, @@ -140,7 +141,7 @@ def get_llm_args(model: str, "kv_cache_config": kv_cache_config, "backend": - backend if backend == "pytorch" else None, + backend, "num_postprocess_workers": num_postprocess_workers, "postprocess_tokenizer_dir": @@ -162,9 +163,15 @@ def launch_server(host: str, backend = llm_args["backend"] model = llm_args["model"] - if backend == 'pytorch': llm = PyTorchLLM(**llm_args) + elif backend == '_autodeploy': + # AutoDeploy does not support build_config + llm_args.pop("build_config", None) + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): + # AutoDeploy does not support cache reuse yet. + llm_args["kv_cache_config"].enable_block_reuse = False + llm = AutoDeployLLM(**llm_args) else: llm = LLM(**llm_args) @@ -204,10 +211,13 @@ def launch_mm_encoder_server( default="localhost", help="Hostname of the server.") @click.option("--port", type=int, default=8000, help="Port of the server.") -@click.option("--backend", - type=click.Choice(["pytorch", "trt"]), - default="pytorch", - help="Set to 'pytorch' for pytorch path. Default is cpp path.") +@click.option( + "--backend", + type=click.Choice(["pytorch", "trt", "_autodeploy"]), + default="pytorch", + help= + "Set to 'pytorch' for pytorch path and '_autodeploy' for autodeploy path. Default is pytorch path." +) @click.option('--log_level', type=click.Choice(severity_map.keys()), default='info', diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 1cb86dfdff..78a0d07620 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -317,7 +317,7 @@ class GenerationExecutorProxy(GenerationExecutor): while True: if self.worker_init_status_queue.poll(1): - ready_signal = self.worker_init_status_queue.get() + ready_signal, error_trace = self.worker_init_status_queue.get() break if any(fut.done() for fut in self.mpi_futures): logger.error("Executor worker died during initialization.") @@ -325,6 +325,7 @@ class GenerationExecutorProxy(GenerationExecutor): self._handle_background_error() if ready_signal != GenerationExecutorProxy.READY_SIGNAL: + logger.error(f"Executor worker initialization error: {error_trace}") self.mpi_session.shutdown_abort(reason=ready_signal) raise RuntimeError( "Executor worker returned error") from ready_signal diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 6d5ec9c1d7..8a1dab6a23 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -774,7 +774,7 @@ def worker_main( logger.error(traceback.format_exc()) print_colored_debug(f"error: {traceback.format_exc()}", "red") if is_leader: - worker_init_status_queue.put(e) + worker_init_status_queue.put((e, traceback.format_exc())) return with worker: @@ -792,7 +792,7 @@ def worker_main( mp_stats_queue) worker._set_iteration_result_queue(worker.kv_events_queues, kv_cache_events_queue) - worker_init_status_queue.put(ready_signal) + worker_init_status_queue.put((ready_signal, None)) while (req := request_queue.get()) is not None: if isinstance(req, CancellingRequest): worker.abort_request(req.id) diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index a5e0c0a5a4..458b0a11d8 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -433,7 +433,7 @@ def apply_chat_template( if model_type in PLACEHOLDER_EXCEPTIONS: # flattened content do not work for these models, so go back to other formats as needed conversation = handle_placeholder_exceptions(model_type, conversation, - mm_placeholder_counts) + [mm_placeholder_counts]) return tokenizer.apply_chat_template( conversation=conversation, diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 9022f7070c..43edb6b62c 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -124,15 +124,21 @@ class BaseLLM: self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) self._llm_id = None + log_level = logger.level + logger.set_level("info") # force display the backend + try: backend = kwargs.get('backend', None) - if backend == 'pytorch': + if backend == "pytorch": + logger.info("Using LLM with PyTorch backend") llm_args_cls = TorchLlmArgs elif backend == '_autodeploy': + logger.info("Using LLM with AutoDeploy backend") from .._torch.auto_deploy.llm_args import \ LlmArgs as AutoDeployLlmArgs llm_args_cls = AutoDeployLlmArgs else: + logger.info("Using LLM with TensorRT backend") llm_args_cls = TrtLlmArgs # check the kwargs and raise ValueError directly @@ -162,6 +168,9 @@ class BaseLLM: f"Failed to parse the arguments for the LLM constructor: {e}") raise e + finally: + logger.set_level(log_level) # restore the log level + print_colored_debug(f"LLM.args.mpi_session: {self.args.mpi_session}\n", "yellow") self.mpi_session = self.args.mpi_session diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index da5071e3b0..6ed4dea76c 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from enum import Enum, EnumMeta from pathlib import Path from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, - Type, TypeAlias, TypeVar, Union, get_args, get_origin) + Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin) import torch import yaml @@ -352,6 +352,7 @@ class DecodingBaseConfig(StrictBaseModel): # When specified, speculation will be disabled at batch sizes above # this value. Otherwise, speculation will always be on. max_concurrency: Optional[int] = None + load_format: Optional[str] = None @classmethod def from_dict(cls, data: dict): @@ -424,6 +425,7 @@ class EagleDecodingConfig(DecodingBaseConfig): num_eagle_layers: Optional[int] = None max_non_leaves_per_layer: Optional[int] = None eagle3_one_model: Optional[bool] = True + eagle3_layers_to_capture: Optional[Set[int]] = None @classmethod def from_dict(cls, data: dict): @@ -443,6 +445,17 @@ class EagleDecodingConfig(DecodingBaseConfig): return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL return TorchSpeculativeDecodingMode.EAGLE3 + @functools.cached_property + def num_capture_layers(self): + """ + Returns the number of layers to capture of the target model. + If eagle3_layers_to_capture is not None, return the length of the set. + Otherwise, assume Eagle3 base set and return 3. + """ + if self.eagle3_layers_to_capture is not None: + return len(self.eagle3_layers_to_capture) + return 3 + class UserProvidedDecodingConfig(DecodingBaseConfig): # Cannot use real type annotations due to circular imports @@ -523,7 +536,9 @@ class MTPDecodingConfig(DecodingBaseConfig): @classmethod def from_dict(cls, data: dict): - return cls(**data) + out = cls(**data) + out.max_draft_len = out.num_nextn_predict_layers + return out decoding_type: ClassVar[str] = "MTP" diff --git a/tensorrt_llm/lora_helper.py b/tensorrt_llm/lora_helper.py index 37f5d534f7..719df51079 100644 --- a/tensorrt_llm/lora_helper.py +++ b/tensorrt_llm/lora_helper.py @@ -88,6 +88,7 @@ class LoraConfig(DictConversion): trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict) max_loras: Optional[int] = None max_cpu_loras: Optional[int] = None + swap_gate_up_proj_lora_b_weight: bool = True def __post_init__(self): assert self.lora_ckpt_source in [ diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 7440715474..f1ca920415 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -243,6 +243,7 @@ class LoraModelConfig: trtllm_modules_to_hf_modules: dict[str, str] hidden_size: int dtype: str + swap_gate_up_proj_lora_b_weight: bool = True class HfLoraLoader: @@ -968,16 +969,17 @@ class LoraManager(object): ) hf_modules = set(hf_modules_to_trtllm_modules.keys()) - def preprocess_lora_weights(lora_model): + def preprocess_lora_weights(lora_model, model_config): # Swap weights of gate_up_proj - for key, value in lora_model.items(): - if "gate_up_proj.lora_B.weight" in key: - original_weights = value.contiguous().clone() - half_split = original_weights.shape[0] // 2 - first_half = original_weights[:half_split, :] - second_half = original_weights[half_split:, :] - value = torch.cat((second_half, first_half), dim=0) - lora_model[key] = value + if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True): + for key, value in lora_model.items(): + if "gate_up_proj.lora_B.weight" in key: + original_weights = value.contiguous().clone() + half_split = original_weights.shape[0] // 2 + first_half = original_weights[:half_split, :] + second_half = original_weights[half_split:, :] + value = torch.cat((second_half, first_half), dim=0) + lora_model[key] = value return lora_model def load_from_model_dir(uid, model_dir, hf_config): @@ -989,7 +991,7 @@ class LoraManager(object): lora_model = load_state_dict(get_model_path(model_dir, "adapter_model")) if lora_model is None: raise ValueError(f"Failed to load adapter_model from {model_dir}") - lora_model = preprocess_lora_weights(lora_model) + lora_model = preprocess_lora_weights(lora_model, model_config) all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component) rank = int(hf_config["r"]) rs_lora = bool(hf_config.get("use_rslora", False)) diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 9a2096852b..de3943c563 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -1190,8 +1190,18 @@ def build_mllama_engine(args): model = MllamaForConditionalGeneration.from_pretrained(args.model_path, torch_dtype='auto', device_map='auto') - wrapper = MLLaMAVisionWrapper(model.vision_model, - model.multi_modal_projector) + + # Check if the model structure is updated to transformers >= 4.52.0 + if hasattr(model, 'model') and hasattr(model.model, 'vision_model'): + vision_model = model.model.vision_model + multi_modal_projector = model.model.multi_modal_projector + else: + # transformers < 4.52.0 + vision_model = model.vision_model + multi_modal_projector = model.multi_modal_projector + + wrapper = MLLaMAVisionWrapper(vision_model, multi_modal_projector) + model_dtype = model.dtype image = Image.new('RGB', [2048, 2688]) # dummy image inputs = processor(images=image, diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md b/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md new file mode 100644 index 0000000000..b7b9f084de --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md @@ -0,0 +1,174 @@ +# gputrc2graph.py + +This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files +(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level +summaries and visualizations of GPU and non-GPU time. It is useful for +profiling and analyzing nsys profile output. + +## Usage + +### Command-line Arguments + +- `--in_file` + **(required)** + List of input files and their metadata. Each entry should be in the format: + `,,,` + - `nsys-rep`: Path to the `.nsys-rep` file. + - `engine`: Engine name (e.g., `trtllm`). + - `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`). + - `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without + profiling. Specify `0` to use the elapsed GPU time calculated from the nsys-rep file (this may inflate non-GPU time if actual runtime without profiling is less). Multiple entries can be provided, separated by spaces. + +- `--out_dir` + Output directory for the generated CSV and HTML files. + If not specified, results are saved in the current directory. + +- `--title` + Title for the HTML chart/visualization. + +- `--nsys_cmd` + Path to the `nsys` command. + Default: `nsys` (assumes it is in your PATH). + Use this if `nsys` is not in your system PATH. + +## Notes + +- Make sure you have pandas and plotly python packages installed. +- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is +installed, and specify the path to the `nsys` command with `--nsys_cmd` if it + is not in your PATH. +- For more details on available engines and models, see the help string in + the script or run: + +```bash +python3 gputrc2graph.py --help +``` + +## Example 1: analyze a single profile + +To analyze the GPU cycles of for example, a llama-3.1-8B model with trtllm: + +1. Run the following command to collect nsys profile, for trtllm serve config. + + ```bash + nsys profile -t cuda -o nsys_res -f true --trace-fork-before-exec=true \ + --cuda-graph-trace=node --delay --duration \ + python3 -m trtllm-serve meta-llama/Llama-4-Scout-17B-16E-Instruct ... + ``` + + where: + + - DELAY: how many seconds to delay nsys from collecting profiles, needed so + that profiles aren't captured till trtllm server has come up and load + generation starts. + - DURATION: how many seconds for nsys profile to run before generating the + profile. This should be > the duration of the run. + +2. Run again, this time without collecting the profile, and get the total run + time in seconds. This value will be used by the script to calculate the + CPU(non-GPU) seconds for the analysis. + +3. Say the run elapsed time is .35 seconds, from step #2. Run script to + analyze: + + ```bash + python3 gputrc2graph.py \ + --in_file run1.nsys-rep,trtllm,llama,.35 + ``` + +The command will produce 2 files for analysis: + +- result.html: this categorizes kernel names into different categories in a + stacked bar chart. +- result.csv: shows how the kernel names are mapped to the different + categories. + +### HTML visualization with result.html + +The html file shows the number of elapsed seconds due to different GPU +Substages or categories, which consist of moe_gemm as the biggest +category, at .14 seconds, followed by "attn" kernels. This lets the user +prioritize the kernels to focus on for performance optimizations. + +![Example GPU Trace Visualization](images/html.png) + +There's also an appended data table underneath the bar chart for copying out to + other post-processing tools. + +![Example GPU Trace Visualization Table](images/html_tbl.png) + +### Kernel to category mapping with result.csv + +Suppose the user would like to focus on improving decreasing calls to nccl +kernels. The next step is to use the result.csv to dive into what the kernels +are which compose the nccl GPU cycles. The following image shows that +ar_fusion all reduce kernel to be the biggest contributor to GPU cycles for +nccl, followed by AllGather. + +![Example GPU Trace csv](images/csv.png) + +## Example 2: analyze multiple profiles + +Suppose the user has multiple nsys trace files, captured for different models, +say llama and gpt-oss in this case, and wish to compare their GPU/non-GPU +time, something like the following command can be used. + +```bash +python3 gputrc2graph.py \ +--in_file run1.nsys-rep,trtllm,llama,100 run2.nsys-rep,trtllm,gpt-oss,102 \ +--out_dir results +``` + +The analysis process is similar to example 1 but now there will be multiple +stack bar charts that can be compared. The categories for the different +kernels will remain the same, so that it's easy to compare the GPU cycles for +the same categories. + +Once a category is shown to have more cycles for one configuration than +another, the next step would be to use the csv file to see what kernels are +mapped into that category, and which kernels are taking the largest amount of +time which would cause a difference for the overall category. + +## Example 3: add new classification for a new model + +To create a new engine DEF with model ABC, just add another json file in the +same directory as gputrc2graph.py with the same format as the other json files. +The script will automatically pick up all the json files in the same directory +as engine/model specifications. + +Then, for this new model, suppose there are 4 kernels to be classified into +"gemm" and "attn", where the gemm kernelshave names with "*H*" or "*I*" in +them, and attn kernels have names with "*J*" or "*K*" in them, just add another + .json file in the same directory as gputrc2graph.py with the same format as + the other json files, like the following: + +```json +{ + "DEF": { + "ABC": { + "H|I": "gemm", + "J|K": "attn", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} +``` + +Each entry in the dictionary consists of: + +- key: a regex used to classify the kernels +- value: the category to classify the kernels into. + +The last 2 entries are common for all engine/models, consisting of CUDA memory +operations and a 'misc' for anything that's leftover and can't be classified. + +When invoking gputrc2graph.py, specify a trace file with this new model/engine +like the following: + +```bash +--in_file new.nsys-rep,DEF,ABC, +``` + +If the engine_DEF.json file already exists, just add the model as a new node in + the existing engine file, after the other models. diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py new file mode 100755 index 0000000000..1ca8a0ff23 --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + This generates gpu kernel analysis output from nsys rep. Will call nsys + stats -r cuda_gpu_trace, get non-overlapped gpu cycles, then generate + csv and html output for analysis +""" + +import argparse +import logging +import os + +import regex as re + +logger = logging.getLogger(__name__) + + +# helper data class for annotating kernels +def load_engine_model(): + """returns engine_model built from all json files in the current dir""" + import glob + import json + + engine_model = {} + + json_files = glob.glob( + os.path.join(os.path.dirname(__file__) or ".", "*.json")) + for fname in json_files: + with open(fname, encoding="utf-8") as f: + engine_model.update(json.load(f)) + return engine_model + + +class GPUTrace2Graph: + """ + Parses output of nsys report, generates csv and bar chart output + """ + + def __init__(self): + import pandas as pd # avoid importing till needed + + self.pd = pd + self.pd.options.mode.copy_on_write = True + + # helper functions for generating trace->summary csvs + def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): + logger.info("loading %s", in_file) + df = self.pd.read_csv(in_file, + usecols=["Start (ns)", "Duration (ns)", "Name"]) + if df.empty: + return + df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"] + df = self.sum_non_overlapping_intervals(df) + # get ready to print table with elapsed times per kernel + df["Instances"] = 1 + df_sum = df.groupby("Name", as_index=False).agg({ + "Elapsed Time (ns)": "sum", + "Duration (ns)": "sum", + "Instances": "size" + }) + + # generate csv + df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9 + df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9 + df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False) + df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", + "Name"]].to_csv(out_file, index=False) + + def sum_non_overlapping_intervals(self, df): + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations + """ + logger.info("sorting %s trace records by start time", str(df.shape)) + assert not df.empty, 'empty nsys records' + # Sort by start time and reset index + df = df.sort_values(by="Start (ns)").reset_index(drop=True) + + # Initialize elapsed time as duration + df["Elapsed Time (ns)"] = df["Duration (ns)"] + + # Get numpy arrays for faster operations + starts = df["Start (ns)"].values + ends = df["End (ns)"].values + + # Keep track of current interval end + current_end = ends[0] + display_units = max(1, int(len(df) / 100)) + # Update current_end for overlapping intervals + for i in range(1, len(df)): + if i % display_units == 0: + print(f"processing trace: {int(i/len(df) * 100)} %", end="\r") + if starts[i] <= current_end: + if ends[i] > current_end: + # Partial overlap + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = ( + ends[i] - current_end) + current_end = ends[i] + else: + # Complete overlap + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0 + else: + # No overlap + current_end = ends[i] + + return df + + # functions for generating html files + def make_html(self, df, output_dir, title): + """make html graph from df""" + import plotly.express as px + + if df.empty: + return + output_name = os.path.join(output_dir, "result") + if not title: + title = "Model_Engine" + x = "Model_Engine" + y = "Elapsed Time (sec)" + color = "Category" + """ generate kernel mapping table """ + # Sort Model_Engine categories by last field after underscore + df["Model_Engine"] = self.pd.Categorical( + df["Model_Engine"], + sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]), + ) + df[["Model_Engine", color, "Instances", "Name", + y]].sort_values(by=color).to_csv(f"{output_name}.csv", index=False) + graph = px.histogram( + df.round(2), + x=x, + y=y, + title=(f"{y} for {title}"), + color=color, + text_auto=True, + ) + # wrap x axis labels + graph.update_xaxes(automargin=True) + graph.write_html(f"{output_name}.html") + """ + Generate data table with columns per Model_Engine into result.html + """ + pivot_df = df.pivot_table( + values="Elapsed Time (sec)", + index="Category", + columns="Model_Engine", + aggfunc="sum", + observed=False, + ).round(2) + # Add sum row at bottom + pivot_df.loc["total_elapsed_sec"] = pivot_df.sum() + pivot_df.fillna("").to_html("temp.html") + with ( + open(f"{output_name}.html", "a", encoding="utf-8") as outfile, + open("temp.html", encoding="utf-8") as infile, + ): + outfile.write(infile.read()) + os.remove("temp.html") + + print(f"Finished generating: \n" + f" {output_name}.html for stack bar chart \n" + f" {output_name}.csv for Kernel-Category mapping") + + def anno_gpu_kernname(self, df, mapping): + """add "Category" column""" + + def anno_gpu_kernname_helper(name): + for kern_name, val in mapping.items(): + if re.search(kern_name, name): + return val + + df["Category"] = df["Name"].apply(anno_gpu_kernname_helper) + + def make_nongpu_row(self, df, nongpu_sec): + """this will append non-gpu time entry at end of df""" + nongpu_row = self.pd.DataFrame([df.iloc[-1]]) + nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)" + nongpu_row["Instances"] = 1 + nongpu_row["Elapsed Time (sec)"] = nongpu_sec + return nongpu_row + + def is_valid_file(self, base_file): + """asserts if base_file is non-existent or is empty""" + assert (os.path.isfile(base_file) and os.path.getsize(base_file) + > 0), f"{base_file} doesn't exist or is empty" + + def should_gen_file(self, new_file, base_file): + """figure out if new file should be generated from base_file""" + self.is_valid_file(base_file) + if (os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0)): + logger.info("reusing %s", new_file) + return False + else: + logger.info("generating %s", new_file) + return True + + def gen_sum_file(self, file, nsys_cmd): + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file + """ + import subprocess # nosec B404 + + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + + if not file_dir: + file_dir = "." + # Walk through trace and get the total non-overlapped time + nsys_stats_file = os.path.join(file_dir, + f"{file_name}_cuda_gpu_trace.csv") + sum_file = os.path.join(file_dir, + f"{file_name}_cuda_gpu_kernel_tracesum.csv") + if self.should_gen_file(nsys_stats_file, file): + cmd = [ + nsys_cmd, + "stats", + "-r", + "cuda_gpu_trace", + file, + "-o", + f"{file_dir}/{file_name}", + ] + cmd_str = " ".join(cmd) + logger.info("+ %s", cmd_str) + # estimate time based on calibrated 240M/min + file_size_mb = os.path.getsize(file) / 1e6 + logger.info( + "nsys stats for %.2f MB file expected to take %.2f min", + file_size_mb, + file_size_mb / 240, + ) + try: + subprocess.run(cmd) + except Exception: + logger.error("%s failed; Use --nsys_cmd to specify nsys path", + cmd_str) + exit(1) + logger.info("generating non-overalapped sum %s", sum_file) + self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) + self.is_valid_file(sum_file) + logger.info("Finished generating %s", sum_file) + return sum_file + + def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): + """generates graph and csv file from in_file into out_dir""" + # Initialize an empty DataFrame to store combined data + combined_df = self.pd.DataFrame() + for idx, (file, engine, model, total_sec) in enumerate(in_file): + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + if not file_dir: + file_dir = "." + sum_file = self.gen_sum_file(file, nsys_cmd) + # read kernel summary file + df = self.pd.read_csv(sum_file) + # annotate kernel to their categories + assert engine_model.get(engine), f"engine {engine} unknown" + assert engine_model[engine].get(model), f"model {model} unknown" + # remove nsys-rep from file_name for shorter x-label + file_name = file_name.replace(".nsys-rep", "") + df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}" + self.anno_gpu_kernname(df, engine_model[engine][model]) + # patch in non-gpu time + gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1) + total_sec = round(float(total_sec), 1) + if total_sec < gpu_sec: + logger.warning( + "Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ", + total_sec, + gpu_sec, + ) + total_sec = gpu_sec + nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec) + df = self.pd.concat([df, nongpu_row], ignore_index=True) + combined_df = self.pd.concat([combined_df, df], ignore_index=True) + if out_dir is None: + out_dir = "." + else: + os.makedirs(out_dir, exist_ok=True) + # generate html file + self.make_html(combined_df, out_dir, title) + + +def parse_tuple(s): + return tuple(s.split(",")) + + +def main(): + logging.basicConfig(format=("%(asctime)s - %(levelname)s - %(message)s"), + level=logging.INFO) + parser = argparse.ArgumentParser( + description=( + "Process nsys rep and generate kernel non-overlapped cycles. \n" + "Example:\n" + "gputrc2graph.py --in_file d1.nsys-rep,trtllm,llama,100 \n" + "d2.nsys-rep,trtllm,gpt-oss,102 " + '--out_dir results/ --title "Model=gpt-oss TRTLLM chart"'), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # load supported engine_model + engine_model_supported = load_engine_model() + # Get a string representation of supported engine/model combinations + engine_model_supported_str = ", ".join( + f"{engine}:[{', '.join(models.keys())}]" + for engine, models in engine_model_supported.items()) + parser.add_argument( + "--in_file", + type=parse_tuple, + nargs="+", + help=("list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) " + "separated by space. Elapsed_nonprofiled_sec is runtime without " + "profiling used to calculate non-gpu time. Specify 0 to use " + "elapsed time from nsys-rep but that might inflate non-gpu time. " + f"Available engine:[model] are: {engine_model_supported_str} " + f"Example: --in_file d1.nsys-rep,sglan,llama,100 " + "d2.nsys-rep,trtllm,gpt-oss,102"), + required=True, + ) + parser.add_argument("--out_dir", help=("output dir for result.csv/html")) + parser.add_argument("--title", help=("title for html chart")) + parser.add_argument( + "--nsys_cmd", + help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"), + default="nsys", + ) + args = parser.parse_args() + gputrace = GPUTrace2Graph() + gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, + engine_model_supported) + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png new file mode 100644 index 0000000000..3fd412f657 Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png new file mode 100644 index 0000000000..35992c7f89 Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png new file mode 100644 index 0000000000..cb134a014c Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json b/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json new file mode 100644 index 0000000000..9287a6d9c6 --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json @@ -0,0 +1,62 @@ +{ + "trtllm": { + "llama": { + "Fused_Moe_Kernel|gemm::|fused_moe|bmm_|GemmUniversal": "moe_gemm", + "gemm|nvjet_": "gemm", + "moe|Expert|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|AllReduce": "nccl_and_custom_ar", + "RMSNormKernel": "norm", + "topk": "topk", + "act_and_mul_|Activation": "activation", + "Rotary": "rope", + "SoftMax": "softmax", + "flash|splitKreduce|kernel_mha|mmha|fmha": "attn", + "elementwise": "elementwise", + "Quantize|cvt_": "quantize", + "reduce_kernel": "reduce", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "ds": { + "fp8_blockscale_gemm": "block_fp8_gemm", + "gemm::GroupProblemShape|Fused_Moe_Kernel|bmm_": "moe_gemm", + "gemm|matmul|nvjet|gemvx": "gemm", + "moe|buildExpertMaps|Moe|Expert|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "topk": "topk", + "act_and_mul_|Activation": "activation", + "Rope": "rope", + "elementwise": "elementwise", + "fmha|flash_fwd_kernel": "attn", + "Quantize|fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "gpt-oss": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + "moe|sigmoid|expert|splitKreduce|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "sbtopk": "topk", + "act_and_mul_|Activation": "activation", + "Rope": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "fmha|mha|flash_fwd_kernel": "attn", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 603fd689b7..93b6027df5 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.1.0rc1" +__version__ = "1.1.0rc2" diff --git a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml index 0c469f2d94..dbf2be50f3 100644 --- a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml +++ b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml @@ -5,6 +5,9 @@ google/gemma-3-1b-it: accuracy: 20.699 google/gemma-3-27b-it: - accuracy: 28.90 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 27.90 gpt2: - accuracy: 18.408 - quant_algo: W8A16 diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 33c264b9e4..ddf3ab5a86 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -150,8 +150,14 @@ speakleash/Bielik-11B-v2.2-Instruct: accuracy: 40.41 google/gemma-3-1b-it: - accuracy: 25.52 # score getting from lm-eval with HF implementation + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 23.96 google/gemma-3-27b-it: - accuracy: 91.66 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 90.66 mistralai/Ministral-8B-Instruct-2410: - accuracy: 79.25 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 9dd1c25d3c..9786e417b2 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -119,6 +119,9 @@ google/gemma-3-1b-it: accuracy: 37.5 google/gemma-3-27b-it: - accuracy: 77.80 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 76.80 Qwen/Qwen2-0.5B-Instruct: - accuracy: 45.30 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 63509cd698..adcc4be979 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -129,37 +129,81 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], ctx_total_gpus = ctx_tp * ctx_pp gen_total_gpus = gen_tp * gen_pp - env_ctx = os.environ.copy() - env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1" - env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus))) + ctx_urls = disaggregated_server_config["context_servers"]["urls"] + gen_urls = disaggregated_server_config["generation_servers"]["urls"] - env_gen = os.environ.copy() - env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1" - env_gen["CUDA_VISIBLE_DEVICES"] = ",".join( - map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus))) - ctx_server_args = ctx_args + [ - "--port", "8001", "--extra_llm_api_options", ctx_server_config_path, - f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}" - ] - gen_server_args = gen_args + [ - "--port", "8002", "--extra_llm_api_options", gen_server_config_path, - f"--tp_size={gen_tp}", f"--pp_size={gen_pp}" - ] - if "max_num_tokens" in ctx_server_config: - ctx_server_args.append( - f"--max_num_tokens={ctx_server_config['max_num_tokens']}") - if "max_num_tokens" in gen_server_config: - gen_server_args.append( - f"--max_num_tokens={gen_server_config['max_num_tokens']}") + ctx_ports = [int(url.split(":")[1]) for url in ctx_urls] + gen_ports = [int(url.split(":")[1]) for url in gen_urls] + + ctx_servers = [] + current_gpu_offset = 0 + + for i, port in enumerate(ctx_ports): + env_ctx = os.environ.copy() + env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1" + gpu_range = range(current_gpu_offset, + current_gpu_offset + ctx_total_gpus) + env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range)) + current_gpu_offset += ctx_total_gpus + + ctx_server_args = ctx_args + [ + "--port", + str(port), "--extra_llm_api_options", ctx_server_config_path, + f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}" + ] + if "max_num_tokens" in ctx_server_config: + ctx_server_args.append( + f"--max_num_tokens={ctx_server_config['max_num_tokens']}") + + ctx_servers.append((env_ctx, ctx_server_args)) + + gen_servers = [] + + for i, port in enumerate(gen_ports): + env_gen = os.environ.copy() + env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1" + gpu_range = range(current_gpu_offset, + current_gpu_offset + gen_total_gpus) + env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range)) + current_gpu_offset += gen_total_gpus + + gen_server_args = gen_args + [ + "--port", + str(port), "--extra_llm_api_options", gen_server_config_path, + f"--tp_size={gen_tp}", f"--pp_size={gen_pp}" + ] + if "max_num_tokens" in gen_server_config: + gen_server_args.append( + f"--max_num_tokens={gen_server_config['max_num_tokens']}") + + gen_servers.append((env_gen, gen_server_args)) + + @contextlib.contextmanager + def multi_popen(server_configs): + processes = [] + try: + for env, args in server_configs: + proc = popen(args, env=env) + processes.append(proc) + + with contextlib.ExitStack() as stack: + opened_processes = [ + stack.enter_context(proc) for proc in processes + ] + yield opened_processes + except Exception as e: + print( + f"Failed to start disaggregated server processes in multi_popen: {e}" + ) + raise with (MyThreadPoolExecutor(max_workers=16) as - thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as - ctx_server, popen(gen_server_args, env=env_gen) as gen_server, + thread_pool, temp_dir, multi_popen(ctx_servers + gen_servers), popen([ trtllm_serve_path, "disaggregated", "-c", disaggregated_serving_config_path, "--server_start_timeout", "3600" - ]) as disaggregated_server): + ])): start_time = time.time() while time.time() - start_time < 3600: time.sleep(1) @@ -225,17 +269,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], return future tokenizer = load_hf_tokenizer(model_name) - - try: - yield DuckLLM(args, tokenizer, generate_async) - finally: - ctx_server.terminate() - gen_server.terminate() - disaggregated_server.terminate() - - ctx_server.wait() - gen_server.wait() - disaggregated_server.wait() + yield DuckLLM(args, tokenizer, generate_async) def run_parallel_test(model_name: str, @@ -244,13 +278,18 @@ def run_parallel_test(model_name: str, ctx_tp: int, gen_pp: int, gen_tp: int, + ctx_instances: int, + gen_instances: int, test_sets: List[LlmapiAccuracyTestHarness], ctx_model: str = None, gen_model: str = None): - if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count(): + total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances + total_gen_gpus = gen_tp * gen_pp * gen_instances + if total_ctx_gpus + total_gen_gpus > get_device_count(): pytest.fail( - f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test" + f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}" ) + kv_cache_config = { "free_gpu_memory_fraction": 0.5, } @@ -272,17 +311,21 @@ def run_parallel_test(model_name: str, "backend": "DEFAULT" } } + + ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)] + gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)] + disaggregated_server_config = { "hostname": "localhost", "port": 8000, "backend": "pytorch", "context_servers": { - "num_instances": 1, - "urls": ["localhost:8001"] + "num_instances": ctx_instances, + "urls": ctx_urls }, "generation_servers": { - "num_instances": 1, - "urls": ["localhost:8002"] + "num_instances": gen_instances, + "urls": gen_urls } } with launch_disaggregated_llm(disaggregated_server_config, @@ -532,8 +575,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): if tp * pp * 2 > get_device_count(): pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp, - tp, [get_accuracy_task(testset)]) + tp, 1, 1, [get_accuracy_task(testset)]) + @pytest.mark.skip_less_device(4) @parametrize_with_ids("ctx_pp", [2, 4]) @parametrize_with_ids("gen_tp", [1, 2]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) @@ -542,7 +586,13 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): pytest.skip( f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, - gen_tp, [get_accuracy_task(testset)]) + gen_tp, 1, 1, [get_accuracy_task(testset)]) + + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) + def test_multi_instance(self, testset): + return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1, + 2, 2, [get_accuracy_task(testset)]) @pytest.mark.skip_less_device_memory(140000) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index da64969337..d761ae6851 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -30,6 +30,8 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness): return { 'skip_tokenizer_init': False, 'trust_remote_code': True, + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): + # AutoDeploy does not support cache reuse yet. 'kv_cache_config': { 'enable_block_reuse': False, }, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 7af8c437d0..0390c97e64 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -15,6 +15,7 @@ import os import pytest +import torch from defs.conftest import get_sm_version from tensorrt_llm import LLM @@ -398,6 +399,40 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness): task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("disable_overlap_scheduler", [True, False]) + @pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"]) + def test_return_logits_pp(self, pp_size, disable_overlap_scheduler): + prompts = ["A B C"] + + llm = LLM(model=self.MODEL_PATH, + pipeline_parallel_size=pp_size, + disable_overlap_scheduler=disable_overlap_scheduler) + + sampling_params = SamplingParams(max_tokens=8, + return_context_logits=True, + return_generation_logits=True, + logprobs=True) + + with llm: + for output in llm.generate(prompts, + sampling_params=sampling_params): + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + assert expected_len == output.context_logits.shape[0] + + gen_logits = output.outputs[0].generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == sampling_params.max_tokens + assert torch.argmax( + gen_logits, dim=1).tolist() == output.outputs[0].token_ids + + assert len( + output.outputs[0].logprobs) == sampling_params.max_tokens + class TestLlama3_2_3B(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.2-3B" @@ -842,6 +877,25 @@ class TestGemma3_27BInstruct(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + def test_fp8_prequantized(self): + # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size. + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False, + dtype="fp8") + # Note: This has only the LLM part quantized. Vision part is in bfloat16. + prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-27b-it-fp8/" + with LLM(prequantized_model_path, + kv_cache_config=kv_cache_config, + attn_backend="FLASHINFER", + cuda_graph_config=None) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + task = CnnDailymail(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): MODEL_NAME = "google/gemma-3-1b-it" @@ -875,6 +929,8 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) task = MMLU(self.MODEL_NAME) task.evaluate(llm) @@ -1734,6 +1790,7 @@ class TestKimiK2(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/Kimi-K2-Instruct" @pytest.mark.skip_less_mpi_world_size(8) + @skip_post_blackwell @skip_pre_hopper @pytest.mark.parametrize( "tp_size,pp_size,ep_size,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size", @@ -2445,11 +2502,12 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness): [ (8, 1, 8, True, True, True, "CUTLASS", False), (8, 1, 8, True, True, True, "TRTLLM", False), - (8, 1, 8, False, False, False, "TRTLLM", True), + (8, 1, 8, True, True, True, "TRTLLM", True), ], ids=[ - "latency_moe_cutlass", "latency_moe_trtllm", - "latency_moe_trtllm_eagle3" + "latency_moe_cutlass", + "latency_moe_trtllm", + "latency_moe_trtllm_eagle3", ], ) def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, @@ -2484,6 +2542,50 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(4) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3", + [ + (4, 1, 4, False, False, False, "TRTLLM", + True), # TP8 has bug when we use TRTLLM moe backend and eagle3 + ], + ids=[ + "latency_moe_trtllm_eagle3", + ], + ) + def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp, + cuda_graph, overlap_scheduler, moe_backend, eagle3): + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4, + enable_block_reuse=not eagle3) + spec_config = None + if eagle3: + spec_config = EagleDecodingConfig( + max_draft_len=2, + speculative_model_dir= + f"{llm_models_root()}/Qwen3/qwen3-235B-eagle3/", + eagle3_one_model=True) + with LLM( + f"{llm_models_root()}/Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + **pytorch_config, + enable_attention_dp=attention_dp, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness): MODEL_NAME = "microsoft/Phi-4-mini-instruct" diff --git a/tests/integration/defs/common.py b/tests/integration/defs/common.py index a61a5b8c28..7136bac02d 100644 --- a/tests/integration/defs/common.py +++ b/tests/integration/defs/common.py @@ -956,3 +956,23 @@ def get_dummy_spec_decoding_heads(hf_model_dir, export_hf_checkpoint(model, dtype=model.config.torch_dtype, export_dir=os.path.join(save_dir, 'fp8')) + + +def get_mmlu_accuracy(output): + mmlu_line = None + for line in output.split('\n'): + if "MMLU weighted average accuracy:" in line: + mmlu_line = line + break + + if mmlu_line is None: + raise Exception( + f"Could not find 'MMLU weighted average accuracy:' in output. Full output:\n{output}" + ) + + mmlu_accuracy = float( + mmlu_line.split("MMLU weighted average accuracy: ")[1].split(" (")[0]) + + print(f"MMLU weighted average accuracy is: {mmlu_accuracy}") + + return mmlu_accuracy diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 0d86204ecb..24000b1f80 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -1235,7 +1235,7 @@ def get_config_for_benchmark(model_root, backend): "num_instances": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 384, + "max_seq_len": 320, "tensor_parallel_size": 1, "pipeline_parallel_size": 1, "disable_overlap_scheduler": True, @@ -1251,7 +1251,7 @@ def get_config_for_benchmark(model_root, backend): "pipeline_parallel_size": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 384, + "max_seq_len": 320, "cache_transceiver_config": { "backend": backend, "max_tokens_in_buffer": 512, diff --git a/tests/integration/defs/llmapi/_run_llmapi_llm.py b/tests/integration/defs/llmapi/_run_llmapi_llm.py index 854af24efa..14dde17077 100644 --- a/tests/integration/defs/llmapi/_run_llmapi_llm.py +++ b/tests/integration/defs/llmapi/_run_llmapi_llm.py @@ -1,25 +1,32 @@ #!/usr/bin/env python3 import os +from typing import Optional import click -from tensorrt_llm._tensorrt_engine import LLM -from tensorrt_llm.llmapi import BuildConfig, SamplingParams +from tensorrt_llm._tensorrt_engine import LLM as TrtLLM +from tensorrt_llm.llmapi import LLM, BuildConfig, SamplingParams @click.command() @click.option("--model_dir", type=str, required=True) @click.option("--tp_size", type=int, default=1) @click.option("--engine_dir", type=str, default=None) -def main(model_dir: str, tp_size: int, engine_dir: str): +@click.option("--backend", type=str, default=None) +def main(model_dir: str, tp_size: int, engine_dir: str, backend: Optional[str]): build_config = BuildConfig() build_config.max_batch_size = 8 build_config.max_input_len = 256 build_config.max_seq_len = 512 - llm = LLM(model_dir, - tensor_parallel_size=tp_size, - build_config=build_config) + backend = backend or "tensorrt" + assert backend in ["pytorch", "tensorrt"] + + llm_cls = TrtLLM if backend == "tensorrt" else LLM + + kwargs = {} if backend == "pytorch" else {"build_config": build_config} + + llm = llm_cls(model_dir, tensor_parallel_size=tp_size, **kwargs) if engine_dir is not None and os.path.abspath( engine_dir) != os.path.abspath(model_dir): diff --git a/tests/integration/defs/llmapi/test_llm_api_qa.py b/tests/integration/defs/llmapi/test_llm_api_qa.py new file mode 100644 index 0000000000..def4be0895 --- /dev/null +++ b/tests/integration/defs/llmapi/test_llm_api_qa.py @@ -0,0 +1,70 @@ +# Confirm that the default backend is changed +import os + +from defs.common import venv_check_output + +from ..conftest import llm_models_root + +model_path = llm_models_root() + "/llama-models-v3/llama-v3-8b-instruct-hf" + + +class TestLlmDefaultBackend: + """ + Check that the default backend is PyTorch for v1.0 breaking change + """ + + def test_llm_args_type_default(self, llm_root, llm_venv): + # Keep the complete example code here + from tensorrt_llm.llmapi import LLM, KvCacheConfig, TorchLlmArgs + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + llm = LLM(model=model_path, kv_cache_config=kv_cache_config) + + # The default backend should be PyTorch + assert llm.args.backend == "pytorch" + assert isinstance(llm.args, TorchLlmArgs) + + for output in llm.generate(["Hello, world!"]): + print(output) + + def test_llm_args_type_tensorrt(self, llm_root, llm_venv): + # Keep the complete example code here + from tensorrt_llm._tensorrt_engine import LLM + from tensorrt_llm.llmapi import KvCacheConfig, TrtLlmArgs + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + + llm = LLM(model=model_path, kv_cache_config=kv_cache_config) + + # If the backend is TensorRT, the args should be TrtLlmArgs + assert llm.args.backend in ("tensorrt", None) + assert isinstance(llm.args, TrtLlmArgs) + + for output in llm.generate(["Hello, world!"]): + print(output) + + def test_llm_args_logging(self, llm_root, llm_venv): + # It should print the backend in the log + script_path = os.path.join(os.path.dirname(__file__), + "_run_llmapi_llm.py") + print(f"script_path: {script_path}") + + # Test with pytorch backend + pytorch_cmd = [ + script_path, "--model_dir", model_path, "--backend", "pytorch" + ] + + pytorch_output = venv_check_output(llm_venv, pytorch_cmd) + + # Check that pytorch backend keyword appears in logs + assert "Using LLM with PyTorch backend" in pytorch_output, f"Expected 'pytorch' in logs, got: {pytorch_output}" + + # Test with tensorrt backend + tensorrt_cmd = [ + script_path, "--model_dir", model_path, "--backend", "tensorrt" + ] + + tensorrt_output = venv_check_output(llm_venv, tensorrt_cmd) + + # Check that tensorrt backend keyword appears in logs + assert "Using LLM with TensorRT backend" in tensorrt_output, f"Expected 'tensorrt' in logs, got: {tensorrt_output}" diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index e095f2b85a..15354b36ea 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -159,14 +159,14 @@ def get_model_yaml_config(model_label: str, 'llama_v4_maverick_17b_128e_instruct_fp8' ], 'config': { - 'use_cuda_graph': - True, - 'cuda_graph_padding_enabled': - True, - 'cuda_graph_batch_sizes': [ - 1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 1024, 2048, - 4096, 8192 - ] + 'cuda_graph_config': { + 'enable_padding': + True, + 'batch_sizes': [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 1024, 2048, + 4096, 8192 + ] + } } } ] @@ -191,15 +191,17 @@ def get_model_yaml_config(model_label: str, } if 'phi_4_multimodal_instruct' in model_label: lora_config['lora_config']['lora_target_modules'] = [ - "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h" + "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h" ] lora_config['lora_config']['trtllm_modules_to_hf_modules'] = { "attn_qkv": "qkv_proj", "attn_dense": "o_proj", - "mlp_h_to_4h": "gate_up_proj", + "mlp_gate_up": "gate_up_proj", "mlp_4h_to_h": "down_proj" } lora_config['lora_config']['max_lora_rank'] = 320 + lora_config['lora_config'][ + 'swap_gate_up_proj_lora_b_weight'] = False base_config.update(lora_config) kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig()) diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index da4faf578b..ef615843bb 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -28,8 +28,9 @@ from defs.common import convert_weights from defs.trt_test_alternative import (check_call, check_call_negative_test, check_output) -from .common import (PluginOptions, convert_weights, prune_checkpoint, - quantize_data, refit_model, venv_check_call) +from .common import (PluginOptions, convert_weights, get_mmlu_accuracy, + prune_checkpoint, quantize_data, refit_model, + venv_check_call) from .conftest import (llm_models_root, skip_no_sm120, skip_nvlink_inactive, skip_post_blackwell, skip_pre_blackwell, skip_pre_hopper, tests_path, unittest_path) @@ -42,6 +43,7 @@ if TEST_MEM_USAGE: os.environ['TLLM_LOG_LEVEL'] = 'INFO' _MEM_FRACTION_50 = 0.5 +_MEM_FRACTION_80 = 0.8 _MEM_FRACTION_95 = 0.95 @@ -2404,15 +2406,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality): } expected_keywords = { "image": [ - ["image", "depicts", "mountain", "half", "rock"], - ["road", "car", "lane", "traffic", "bus"], + ["object", "mountain", "weather", "clear", "clouds"], + ["traffic", "road", "vehicles", "cars", "bus"], ], "audio": [ ["what", "is", "the", "traffic", "sign", "in", "image"], ["what", "is", "shown", "in", "this", "image"], ], "image_audio": [ - ["image", "depicts", "Grand", "rock", "scene"], + ["image", "depicts", "scenic", "famous", "landmark"], ], } @@ -2446,6 +2448,109 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality): print("All answers are correct!") +@pytest.mark.skip_less_device(2) +@pytest.mark.skip_less_device_memory(80000) +@pytest.mark.parametrize("model_name,model_path", [ + ("gemma-3-27b-it", "gemma/gemma-3-27b-it"), + ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"), + ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"), +]) +def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, + model_path): + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + test_data_root = Path( + os.path.join(llm_models_root(), "multimodals", "test_data")) + + print(f"Accuracy test {model_name} image mode with example inputs.") + + # Define accuracy inputs for image modality + accuracy_inputs = { + "image": { + "prompt": [ + "Describe the object and the weather condition in the image.", + "Describe the traffic condition on the road in the image.", + ], + "media": [ + str(test_data_root / "inpaint.png"), + str(test_data_root / "61.jpg"), + ], + } + } + + # Define expected keywords for each model + expected_keywords = { + "gemma-3-27b-it": { + "image": [ + ["half", "dome", "yosemite", "landmark", "rounded"], + ["flowing", "traffic", "vehicles", "road", "Changi"], + ], + }, + "mistral-small-3.1-24b-instruct": { + "image": [ + ["scenic", "rock", "landscape", "monolith", "formation"], + [ + "multi-lane", "highway", "moderate", "traffic", "flow", + "vehicles", "congestion" + ], + ], + }, + "Phi-4-multimodal-instruct": { + "image": [ + ["image", "depicts", "mountain", "half", "rock"], + ["road", "car", "lane", "traffic", "bus"], + ], + }, + } + + # Build command for image modality + cmd = [ + str(example_root / "quickstart_multimodal.py"), + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--modality", + "image", + "--prompt", + *accuracy_inputs["image"]["prompt"], + "--media", + *accuracy_inputs["image"]["media"], + "--tp_size", + "2", + ] + + # Add model-specific configurations + if model_name == "gemma-3-27b-it": + # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently. + # Custom mask involves bidirectional masking of image tokens in context phase. To get this + # correct, chunked prefill and kv cache reuse need to be turned off. + cmd.append("--image_format=pil") + cmd.append("--attention_backend=FLASHINFER") + cmd.append("--disable_kv_cache_reuse") + elif model_name == "Phi-4-multimodal-instruct": + # Set max_seq_len to 4096 to use short rope factor. + cmd.append("--max_seq_len=4096") + cmd.append("--load_lora") + cmd.append("--auto_model_name") + cmd.append("Phi4MMForCausalLM") + + output = llm_venv.run_cmd(cmd, caller=check_output) + + # Set match ratio based on model + match_ratio = 4.0 / 5 + if model_name == "Phi-4-multimodal-instruct": + match_ratio = 0.6 + + # Check output accuracy + for prompt_output, prompt_keywords in zip( + parse_output(output), expected_keywords[model_name]["image"]): + matches = [ + keyword in prompt_output.lower() for keyword in prompt_keywords + ] + obs_match_ratio = 1. * sum(matches) / len(matches) + assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}" + + print("All answers are correct!") + + @pytest.mark.parametrize("model_name,model_path", [ ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"), ]) @@ -2574,4 +2679,43 @@ def test_ptp_quickstart_advanced_llama_multi_nodes(llm_root, llm_venv, check_call(" ".join(run_cmd), shell=True, env=llm_venv._new_env) -# End of Pivot-To-Python examples +@pytest.mark.timeout(5400) +@pytest.mark.skip_less_device_memory(80000) +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("eval_task", ["mmlu"]) +@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(16, 1, 8), (8, 2, 8)], + ids=["tp16", "tp8pp2"]) +@pytest.mark.parametrize("model_path", [ + pytest.param('llama-3.3-models/Llama-3.3-70B-Instruct', + marks=skip_pre_hopper), + pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct', + marks=skip_pre_hopper), + pytest.param('llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8', + marks=skip_pre_hopper), + pytest.param('Qwen3/Qwen3-235B-A22B', marks=skip_pre_hopper), + pytest.param('Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf', + marks=skip_pre_blackwell), + pytest.param('DeepSeek-R1/DeepSeek-R1-0528-FP4', marks=skip_pre_blackwell), +]) +def test_multi_nodes_eval(llm_venv, model_path, tp_size, pp_size, ep_size, + eval_task): + if "Llama-4" in model_path and tp_size == 16: + pytest.skip("Llama-4 with tp16 is not supported") + + mmlu_threshold = 81.5 + run_cmd = [ + "trtllm-llmapi-launch", + "trtllm-eval", + f"--model={llm_models_root()}/{model_path}", + f"--ep_size={ep_size}", + f"--tp_size={tp_size}", + f"--pp_size={pp_size}", + f"--kv_cache_free_gpu_memory_fraction={_MEM_FRACTION_80}", + "--max_batch_size=32", + eval_task, + ] + output = check_output(" ".join(run_cmd), shell=True, env=llm_venv._new_env) + + if os.environ.get("SLURM_PROCID", '0') == '0': + mmlu_accuracy = get_mmlu_accuracy(output) + assert mmlu_accuracy > mmlu_threshold, f"MMLU accuracy {mmlu_accuracy} is less than threshold {mmlu_threshold}" diff --git a/tests/integration/defs/triton_server/build_engines.py b/tests/integration/defs/triton_server/build_engines.py index b8298d7309..1155b47d4a 100755 --- a/tests/integration/defs/triton_server/build_engines.py +++ b/tests/integration/defs/triton_server/build_engines.py @@ -1763,3 +1763,75 @@ def prepare_rcca_nvbug_4714193_engine(tensorrt_llm_example_root, assert os.path.exists(engine_dir), f"{engine_dir} does not exists." return engine_dir + + +def prepare_mistral3_pixtral_engine(tensorrt_llm_multimodal_example_root, + tensorrt_llm_llama_example_root, + mistral_small_model_root): + # Convert Mistral3 from HF + model_base_name = os.path.basename(mistral_small_model_root.rstrip("/")) + ckpt_dir = os.path.join(tensorrt_llm_multimodal_example_root, "model_dir", + model_base_name) + convert_cmd = [ + "python3", + f"{tensorrt_llm_llama_example_root}/convert_checkpoint.py", + "--dtype=bfloat16", + f"--model_dir={mistral_small_model_root}", + f"--output_dir={ckpt_dir}", + ] + + # Build Mistral3 LLM engine + engine_dir = os.path.join(tensorrt_llm_multimodal_example_root, + "engine_dir", model_base_name) + + build_cmd = [ + "trtllm-build", + f"--checkpoint_dir={ckpt_dir}", + "--max_batch_size=4", + "--max_input_len=8192", + "--max_seq_len=8192", + # Allow an arbitrary number of image tokens by setting: + # max_multimodal_len = max_batch_size * max_input_len + "--max_multimodal_len=32768", + "--use_paged_context_fmha=enable", + f"--output_dir={engine_dir}", + ] + + # Build Pixtral visual encoder engine + multimodal_engine_dir = os.path.join(tensorrt_llm_multimodal_example_root, + "tmp", "trt_engines", model_base_name, + "multimodal_encoder") + build_visual_engine_cmd = [ + "python3", + "build_multimodal_engine.py", + "--model_type=pixtral", + f"--model_path={mistral_small_model_root}", + f"--output_dir={multimodal_engine_dir}", + "--max_batch_size=2", + ] + + append_timing_cache_args(build_cmd) + convert_cmd = " ".join(convert_cmd) + build_cmd = " ".join(build_cmd) + build_visual_engine_cmd = " ".join(build_visual_engine_cmd) + if not os.path.exists(engine_dir) or not os.path.exists( + multimodal_engine_dir): + check_call(install_requirement_cmd, + shell=True, + cwd=tensorrt_llm_llama_example_root) + check_call(convert_cmd, shell=True) + check_call(build_cmd, shell=True) + check_call(build_visual_engine_cmd, + shell=True, + cwd=tensorrt_llm_multimodal_example_root) + + else: + print_info(f"Reusing engine: {engine_dir}") + print_info(f"Skipped: {convert_cmd}") + print_info(f"Skipped: {build_cmd}") + print_info(f"Skipped: {build_visual_engine_cmd}") + + assert os.path.exists(engine_dir), f"{engine_dir} does not exists." + assert os.path.exists( + multimodal_engine_dir), f"{multimodal_engine_dir} does not exists." + return engine_dir, multimodal_engine_dir diff --git a/tests/integration/defs/triton_server/common.py b/tests/integration/defs/triton_server/common.py index 174c1a7f58..fb41bdc00c 100644 --- a/tests/integration/defs/triton_server/common.py +++ b/tests/integration/defs/triton_server/common.py @@ -247,7 +247,8 @@ def modify_ib_config_pbtxt(REPO_PATH, CROSS_KV_CACHE_FRACTION="", ENCODER_INPUT_FEATURES_DTYPE="TYPE_FP16", GUIDED_DECODING_BACKEND="", - XGRAMMAR_TOKENIZER_INFO_PATH=""): + XGRAMMAR_TOKENIZER_INFO_PATH="", + PROMPT_EMBEDDING_TABLE_DTYPE="TYPE_FP16"): fill_template_py = os.path.join(llm_backend_repo_root, "tools", "fill_template.py") tensorrt_llm_config = os.path.join(llm_backend_repo_root, REPO_PATH, @@ -274,6 +275,7 @@ def modify_ib_config_pbtxt(REPO_PATH, check_call( f"python3 {fill_template_py} -i {multimodal_enc_config} triton_max_batch_size:{TRITON_MAX_BATCH_SIZE}," \ f"multimodal_model_path:{MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \ + f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \ f"hf_model_path:{TOKENIZER_PATH}", shell=True) check_call( @@ -305,6 +307,7 @@ def modify_ib_config_pbtxt(REPO_PATH, f"lookahead_ngram_size:{EXECUTOR_LOOKAHEAD_NGRAM}," \ f"lookahead_verification_set_size:{EXECUTOR_LOOKAHEAD_VERIFICATION_SET}," \ f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \ + f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \ f"participant_ids:{PARTICIPANT_IDS_DRAFT}," \ f"logits_datatype:TYPE_FP32'", shell=True) @@ -329,6 +332,7 @@ def modify_ib_config_pbtxt(REPO_PATH, f"lookahead_ngram_size:{EXECUTOR_LOOKAHEAD_NGRAM}," \ f"lookahead_verification_set_size:{EXECUTOR_LOOKAHEAD_VERIFICATION_SET}," \ f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \ + f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \ f"participant_ids:{PARTICIPANT_IDS_TARGET}," \ f"logits_datatype:TYPE_FP32'", shell=True) @@ -348,7 +352,8 @@ def modify_ib_config_pbtxt(REPO_PATH, check_call( f"python3 {fill_template_py} -i {tensorrt_llm_bls_config} triton_max_batch_size:{TRITON_MAX_BATCH_SIZE}," \ f"decoupled_mode:{DECOUPLED_MODE},accumulate_tokens:{ACCUMULATE_TOKEN},bls_instance_count:{BLS_INSTANCE_COUNT}," \ - f"tensorrt_llm_model_name:{TENSORRT_LLM_TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:{TENSORRT_LLM_DRAFT_MODEL_NAME},logits_datatype:TYPE_FP32", + f"tensorrt_llm_model_name:{TENSORRT_LLM_TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:{TENSORRT_LLM_DRAFT_MODEL_NAME},logits_datatype:TYPE_FP32," \ + f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}", shell=True) check_call( @@ -363,6 +368,7 @@ def modify_ib_config_pbtxt(REPO_PATH, f"gpu_weights_percent:{GPU_WEIGHTS_PERCENT},encoder_engine_dir:{ENCODER_ENGINE_PATH},max_queue_size:{MAX_QUEUE_SIZE}," \ f"enable_context_fmha_fp32_acc:{ENABLE_CONTEXT_FMHA_FP32_ACC}," \ f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \ + f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \ f"participant_ids:{PARTICIPANT_IDS}," \ f"logits_datatype:TYPE_FP32,guided_decoding_backend:{GUIDED_DECODING_BACKEND},tokenizer_dir:{TOKENIZER_PATH},xgrammar_tokenizer_info_path:{XGRAMMAR_TOKENIZER_INFO_PATH}'", shell=True) diff --git a/tests/integration/defs/triton_server/conftest.py b/tests/integration/defs/triton_server/conftest.py index 2afebbee14..d66bc0f09d 100644 --- a/tests/integration/defs/triton_server/conftest.py +++ b/tests/integration/defs/triton_server/conftest.py @@ -564,6 +564,19 @@ def tiny_llama_model_root(): return tiny_llama_model_root +@pytest.fixture(scope="session") +def mistral_small_3_1_24b_model_root(): + models_root = llm_models_root() + assert models_root, "Did you set LLM_MODELS_ROOT?" + model_root = os.path.join(models_root, + "Mistral-Small-3.1-24B-Instruct-2503") + + assert os.path.exists( + model_root + ), f"{model_root} does not exist under NFS LLM_MODELS_ROOT dir" + return model_root + + # Returns an array of total memory for each available device @pytest.fixture(scope="session") def total_gpu_memory_mib(): diff --git a/tests/integration/defs/triton_server/test.sh b/tests/integration/defs/triton_server/test.sh index e0819738ef..7782ecfe97 100755 --- a/tests/integration/defs/triton_server/test.sh +++ b/tests/integration/defs/triton_server/test.sh @@ -163,6 +163,7 @@ print_test_params () { echo "DECODING_MODE: ${DECODING_MODE}" echo "MAX_QUEUE_SIZE: ${MAX_QUEUE_SIZE}" echo "ENABLE_CONTEXT_FMHA_FP32_ACC: ${ENABLE_CONTEXT_FMHA_FP32_ACC}" + echo "PROMPT_EMBEDDING_TABLE_DTYPE: ${PROMPT_EMBEDDING_TABLE_DTYPE}" echo "run_all_tests: ${run_all_tests}" echo "----------------------------------" } @@ -180,26 +181,26 @@ fill_triton_repo () { fi echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm with engine ${DECODER_ENGINE_PATH}" - python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm/config.pbtxt triton_backend:${BACKEND},engine_dir:${DECODER_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},enable_context_fmha_fp32_acc:${ENABLE_CONTEXT_FMHA_FP32_ACC},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32,lookahead_window_size:${LOOKAHEAD_WINDOW_SIZE},lookahead_ngram_size:${LOOKAHEAD_NGRAM_SIZE},lookahead_verification_set_size:${LOOKAHEAD_VERIFICATION_SET_SIZE} + python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm/config.pbtxt triton_backend:${BACKEND},engine_dir:${DECODER_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},enable_context_fmha_fp32_acc:${ENABLE_CONTEXT_FMHA_FP32_ACC},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32,lookahead_window_size:${LOOKAHEAD_WINDOW_SIZE},lookahead_ngram_size:${LOOKAHEAD_NGRAM_SIZE},lookahead_verification_set_size:${LOOKAHEAD_VERIFICATION_SET_SIZE} python3 tools/fill_template.py -i ${TRITON_REPO}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${PREPROCESSING_INSTANCE_COUNT} python3 tools/fill_template.py -i ${TRITON_REPO}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${POSTPROCESSING_INSTANCE_COUNT} python3 tools/fill_template.py -i ${TRITON_REPO}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},logits_datatype:TYPE_FP32 if [ "${DRAFT_ENGINE_PATH}" != "" ] && [ "${DRAFT_ENGINE_PATH}" != "skip" ] && [ "${TARGET_ENGINE_PATH}" != "" ] && [ "${TARGET_ENGINE_PATH}" != "skip" ]; then - python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_TARGET_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:${TENSORRT_LLM_DRAFT_MODEL_NAME} + python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_TARGET_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:${TENSORRT_LLM_DRAFT_MODEL_NAME},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE} else - python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:"" + python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:"",prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE} fi if [ "${DRAFT_ENGINE_PATH}" != "" ] && [ "${DRAFT_ENGINE_PATH}" != "skip" ]; then echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm_draft with engine ${DRAFT_ENGINE_PATH}" - python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt triton_backend:${BACKEND},engine_dir:${DRAFT_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32 + python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt triton_backend:${BACKEND},engine_dir:${DRAFT_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32 fi if [ "${TARGET_ENGINE_PATH}" != "" ] && [ "${TARGET_ENGINE_PATH}" != "skip" ]; then echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm_target with engine ${TARGET_ENGINE_PATH}" - python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_target/config.pbtxt triton_backend:${BACKEND},engine_dir:${TARGET_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:true,normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32 + python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_target/config.pbtxt triton_backend:${BACKEND},engine_dir:${TARGET_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:true,normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32 fi @@ -217,7 +218,7 @@ fill_triton_repo () { cp all_models/multimodal/multimodal_encoders ${TRITON_REPO} -r python3 tools/fill_template.py -i ${TRITON_REPO}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},logits_datatype:TYPE_FP32 python3 tools/fill_template.py -i ${TRITON_REPO}/preprocessing/config.pbtxt multimodal_model_path:${MULTIMODAL_ENGINE_PATH},engine_dir:${DECODER_ENGINE_PATH} - python3 tools/fill_template.py -i ${TRITON_REPO}/multimodal_encoders/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},multimodal_model_path:${MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},hf_model_path:${TOKENIZER_PATH} + python3 tools/fill_template.py -i ${TRITON_REPO}/multimodal_encoders/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},multimodal_model_path:${MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},hf_model_path:${TOKENIZER_PATH} python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt multimodal_encoders_name:multimodal_encoders fi @@ -649,6 +650,7 @@ TRITON_METRICS_PORT="8002" GPU_DEVICE_IDS="" DECODING_MODE="top_k_top_p" MAX_QUEUE_SIZE="0" +PROMPT_EMBEDDING_TABLE_DTYPE="TYPE_FP16" if [ "$MODEL" = "gpt-ib" ] || [ "$MODEL" = "mistral-ib" ] || [ "$MODEL" = "mistral-ib-mm" ]; then diff --git a/tests/integration/defs/triton_server/test_triton_llm.py b/tests/integration/defs/triton_server/test_triton_llm.py index d6f4be2b05..02755da458 100644 --- a/tests/integration/defs/triton_server/test_triton_llm.py +++ b/tests/integration/defs/triton_server/test_triton_llm.py @@ -1,4 +1,5 @@ import os +import re import sys import pytest @@ -3893,3 +3894,198 @@ def test_tiny_llama_ifb_token_counts( print_info( f"Successfully tested token count functionality for {TOKEN_COUNT_TEST} mode" ) + + +@pytest.mark.skip_less_device_memory(80000) +@pytest.mark.parametrize("E2E_MODEL_NAME", ["ensemble", "tensorrt_llm_bls"]) +@pytest.mark.parametrize("ACCUMULATE_TOKEN", ["False"]) +@pytest.mark.parametrize("BLS_INSTANCE_COUNT", ["1"]) +@pytest.mark.parametrize("PREPROCESSING_INSTANCE_COUNT", ["1"]) +@pytest.mark.parametrize("POSTPROCESSING_INSTANCE_COUNT", ["1"]) +@pytest.mark.parametrize("MAX_TOKENS_IN_KV_CACHE", [""]) +@pytest.mark.parametrize("MAX_ATTENTION_WINDOW_SIZE", [""]) +@pytest.mark.parametrize("BATCH_SCHEDULER_POLICY", + ["max_utilization", "guaranteed_no_evict"]) +@pytest.mark.parametrize("KV_CACHE_FREE_GPU_MEM_FRACTION", ["0.7"]) +@pytest.mark.parametrize("CROSS_KV_CACHE_FRACTION", [""]) +@pytest.mark.parametrize("ENABLE_TRT_OVERLAP", ["False"], + ids=["disableTrtOverlap"]) +@pytest.mark.parametrize("BATCHING_STRATEGY", ["inflight_fused_batching"]) +@pytest.mark.parametrize("DECOUPLED_MODE", ["True", "False"], + ids=["enableDecoupleMode", "disableDecoupleMode"]) +@pytest.mark.parametrize("TRITON_MAX_BATCH_SIZE", ["1"]) +@pytest.mark.parametrize("MAX_QUEUE_DELAY_MICROSECONDS", ["0"]) +@pytest.mark.parametrize("ENABLE_KV_CACHE_REUSE", ["False"]) +@pytest.mark.parametrize("NORMALIZE_LOG_PROBS", ["True"]) +@pytest.mark.parametrize("ENABLE_CHUNKED_CONTEXT", ["False"]) +@pytest.mark.parametrize("GPU_DEVICE_IDS", [""]) +@pytest.mark.parametrize("DECODING_MODE", [""]) +@pytest.mark.parametrize("MAX_BEAM_WIDTH", ["1"]) +@pytest.mark.parametrize("EXCLUDE_INPUT_IN_OUTPUT", ["False"]) +@pytest.mark.parametrize("PROMPT_EMBEDDING_TABLE_DTYPE", + ["TYPE_BF16"]) # allow override later +@pytest.mark.parametrize("ENCODER_INPUT_FEATURES_DTYPE", + ["TYPE_FP16"]) # pixtral uses fp16 vision by default +def test_mistral_small_3_1_24b_pixtral( + E2E_MODEL_NAME, + MAX_TOKENS_IN_KV_CACHE, + MAX_ATTENTION_WINDOW_SIZE, + BATCH_SCHEDULER_POLICY, + KV_CACHE_FREE_GPU_MEM_FRACTION, + CROSS_KV_CACHE_FRACTION, + ENABLE_TRT_OVERLAP, + BATCHING_STRATEGY, + DECOUPLED_MODE, + TRITON_MAX_BATCH_SIZE, + MAX_QUEUE_DELAY_MICROSECONDS, + MAX_BEAM_WIDTH, + ENABLE_KV_CACHE_REUSE, + NORMALIZE_LOG_PROBS, + ENABLE_CHUNKED_CONTEXT, + GPU_DEVICE_IDS, + DECODING_MODE, + PREPROCESSING_INSTANCE_COUNT, + POSTPROCESSING_INSTANCE_COUNT, + ACCUMULATE_TOKEN, + BLS_INSTANCE_COUNT, + EXCLUDE_INPUT_IN_OUTPUT, + PROMPT_EMBEDDING_TABLE_DTYPE, + ENCODER_INPUT_FEATURES_DTYPE, + tensorrt_llm_multimodal_example_root, + tensorrt_llm_llama_example_root, + mistral_small_3_1_24b_model_root, + llm_backend_multimodal_example_root, + llm_backend_venv, + llm_root, +): + if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization": + pytest.skip("Skipping. V1 doesn't support max_utilization.") + + llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"] + + # Build Engines (LLM + vision) + ENGINE_PATH, MULTIMODAL_ENGINE_DIR = prepare_mistral3_pixtral_engine( + tensorrt_llm_multimodal_example_root, tensorrt_llm_llama_example_root, + mistral_small_3_1_24b_model_root) + + # Prepare model repo + new_model_repo = os.path.join(llm_backend_repo_root, "triton_repo") + prepare_ib_model_repo(llm_backend_repo_root, new_model_repo) + + # Prepare multimodal specific repo + prepare_multimodal_model_repo(llm_backend_repo_root, new_model_repo, + "ensemble") + prepare_multimodal_model_repo(llm_backend_repo_root, new_model_repo, + "multimodal_encoders") + + # Modify config.pbtxt + TOKENIZER_PATH = mistral_small_3_1_24b_model_root + modify_ib_config_pbtxt( + new_model_repo, + ENGINE_PATH, + TOKENIZER_PATH, + llm_backend_repo_root, + DECOUPLED_MODE, + MAX_TOKENS_IN_KV_CACHE, + MAX_ATTENTION_WINDOW_SIZE, + BATCH_SCHEDULER_POLICY, + BATCHING_STRATEGY, + KV_CACHE_FREE_GPU_MEM_FRACTION, + EXCLUDE_INPUT_IN_OUTPUT, + ENABLE_TRT_OVERLAP, + TRITON_MAX_BATCH_SIZE, + MAX_QUEUE_DELAY_MICROSECONDS, + MAX_BEAM_WIDTH, + ENABLE_KV_CACHE_REUSE, + NORMALIZE_LOG_PROBS, + ENABLE_CHUNKED_CONTEXT, + GPU_DEVICE_IDS, + DECODING_MODE, + PREPROCESSING_INSTANCE_COUNT, + POSTPROCESSING_INSTANCE_COUNT, + ACCUMULATE_TOKEN, + BLS_INSTANCE_COUNT, + MULTIMODAL_ENGINE_PATH=MULTIMODAL_ENGINE_DIR, + ENCODER_INPUT_FEATURES_DTYPE=ENCODER_INPUT_FEATURES_DTYPE, + PROMPT_EMBEDDING_TABLE_DTYPE=PROMPT_EMBEDDING_TABLE_DTYPE, + ) + + # Launch Triton Server + launch_server_py = os.path.join(llm_backend_repo_root, "scripts", + "launch_triton_server.py") + check_call( + f"PMIX_MCA_gds=hash python3 {launch_server_py} --world_size=1 --model_repo={new_model_repo}", + shell=True) + check_server_ready() + + image_merlion = os.path.join( + llm_root, + "tests/integration/test_input_files/merlion.png", + ) + image_football = os.path.join( + llm_root, + "tests/integration/test_input_files/pexels-franco-monsalvo-252430633-32285228.jpg", + ) + image_hockey = os.path.join( + llm_root, + "tests/integration/test_input_files/pexels-ron-lach-8975010.jpg", + ) + image_basketball = os.path.join( + llm_root, + "tests/integration/test_input_files/pexels-maxim-shklyaev-1511525-2914194.jpg", + ) + + test_cases = [ + { + "text": "What is the capital of England?", + "image": "", + "match": re.compile("london", re.IGNORECASE) + }, + { + "text": "In as few words as possible, what city is this?", + "image": image_merlion, + "match": re.compile("singapore", re.IGNORECASE) + }, + { + "text": + "In as few words as possible, what sports are depicted in the images?", + "image": + ",".join([image_football, image_hockey]), + "match": + re.compile("(football|soccer).*hockey", re.IGNORECASE | re.DOTALL) + }, + { + "text": + "In as few words as possible, what sports are depicted in the images?", + "image": + ",".join([image_football, image_hockey, image_basketball]), + "match": + re.compile("(football|soccer).*hockey.*basket", + re.IGNORECASE | re.DOTALL) + }, + ] + + for test_case in test_cases: + TEXT = test_case["text"] + IMAGE = test_case["image"] + MATCH = test_case["match"] + + # Run Test: use multimodal client; set model_type to pixtral + run_cmd = [ + f"{llm_backend_multimodal_example_root}/client.py", + "--model_type=pixtral", + f"--text={TEXT}", + f"--image={IMAGE}", + "--request-output-len=128", + "--end-id=2", + ] + if DECOUPLED_MODE == "True": + run_cmd += ["--streaming"] + + if E2E_MODEL_NAME == "tensorrt_llm_bls": + run_cmd += ["--use_bls"] + + output = venv_check_output(llm_backend_venv, run_cmd) + + assert MATCH.search( + output), f"Test failed for input: {TEXT=}, {IMAGE=}, {output=}" diff --git a/tests/integration/test_input_files/excel_table_test.jpg b/tests/integration/test_input_files/excel_table_test.jpg index 5c07c5e61e..f81e6b7bdc 100644 Binary files a/tests/integration/test_input_files/excel_table_test.jpg and b/tests/integration/test_input_files/excel_table_test.jpg differ diff --git a/tests/integration/test_input_files/merlion.png b/tests/integration/test_input_files/merlion.png new file mode 100644 index 0000000000..c8d299377a --- /dev/null +++ b/tests/integration/test_input_files/merlion.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1f3b6a507ec92e8f47ac6d7c64e11b03fcba8c550bcb6851f80e261e8951431 +size 1604159 diff --git a/tests/integration/test_input_files/pexels-franco-monsalvo-252430633-32285228.jpg b/tests/integration/test_input_files/pexels-franco-monsalvo-252430633-32285228.jpg new file mode 100644 index 0000000000..ae27b79375 --- /dev/null +++ b/tests/integration/test_input_files/pexels-franco-monsalvo-252430633-32285228.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bd1efd0c8fe48b421210cd132dc3b3b2902ccf1523bb9bec3a3883bb5c7a650 +size 116299 diff --git a/tests/integration/test_input_files/pexels-maxim-shklyaev-1511525-2914194.jpg b/tests/integration/test_input_files/pexels-maxim-shklyaev-1511525-2914194.jpg new file mode 100644 index 0000000000..238fd51d10 --- /dev/null +++ b/tests/integration/test_input_files/pexels-maxim-shklyaev-1511525-2914194.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd922b837bc92353d49a60df1dd933eddfe7546e2b16b365acaadb9b2a0a683b +size 72231 diff --git a/tests/integration/test_input_files/pexels-ron-lach-8975010.jpg b/tests/integration/test_input_files/pexels-ron-lach-8975010.jpg new file mode 100644 index 0000000000..07ed42b2df --- /dev/null +++ b/tests/integration/test_input_files/pexels-ron-lach-8975010.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31c6fedadcb79990687d00d24350f774f4ad319439c89ed67d47c1df35a556fb +size 83652 diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt index 9e6e12b400..6af50c17b6 100644 --- a/tests/integration/test_lists/qa/llm_function_full.txt +++ b/tests/integration/test_lists/qa/llm_function_full.txt @@ -464,6 +464,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagl accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=False] accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype +accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8-cuda_graph=False] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep4-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep8-cuda_graph=True] @@ -571,6 +572,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] @@ -579,7 +582,7 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] -accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_fp8 @@ -631,7 +634,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-S test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1] test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False] -test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-image-False] test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-video-False] test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-False] @@ -646,6 +648,9 @@ test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio] test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image] test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio] +test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it] +test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503] +test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct] test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] @@ -723,3 +728,9 @@ disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyL # These tests will impact triton. They should be at the end of all tests (https://nvbugs/4904271) # examples/test_openai.py::test_llm_openai_triton_1gpu # examples/test_openai.py::test_llm_openai_triton_plugingen_1gpu + +# llm-api promote pytorch to default +llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging +llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_tensorrt +llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_type_default +llmapi/test_llm_api_qa.py::TestLlmDefaultBackend::test_llm_args_logging diff --git a/tests/integration/test_lists/qa/llm_function_multinode.txt b/tests/integration/test_lists/qa/llm_function_multinode.txt index 1348faa84b..06a3d4714b 100644 --- a/tests/integration/test_lists/qa/llm_function_multinode.txt +++ b/tests/integration/test_lists/qa/llm_function_multinode.txt @@ -1,9 +1,8 @@ -examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-build] -examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-infer] -examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-build] -examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-infer] test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-V3] -test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-R1/DeepSeek-R1-0528-FP4] -test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama-3.3-models/Llama-3.3-70B-Instruct] -test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama4-models/Llama-4-Maverick-17B-128E-Instruct] test_e2e.py::test_openai_multinodes_chat_tp16pp1 +test_e2e.py::test_multi_nodes_eval[llama-3.3-models/Llama-3.3-70B-Instruct-tp16-mmlu] +test_e2e.py::test_multi_nodes_eval[llama4-models/Llama-4-Maverick-17B-128E-Instruct-tp8pp2-mmlu] +test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] +test_e2e.py::test_multi_nodes_eval[Qwen3/Qwen3-235B-A22B-tp16-mmlu] +test_e2e.py::test_multi_nodes_eval[Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf-tp16-mmlu] +test_e2e.py::test_multi_nodes_eval[DeepSeek-R1/DeepSeek-R1-0528-FP4-tp16-mmlu] diff --git a/tests/integration/test_lists/qa/llm_function_sanity.txt b/tests/integration/test_lists/qa/llm_function_sanity.txt index c977a77d3c..7f9c03d963 100644 --- a/tests/integration/test_lists/qa/llm_function_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_sanity.txt @@ -47,6 +47,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype +accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4 @@ -116,7 +117,7 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutl accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_trtllm-torch_compile=True] -accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRITON] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM] @@ -164,7 +165,6 @@ test_e2e.py::test_ptp_quickstart_advanced[Llama3.2-11B-BF16-llama-3.2-models/Lla test_e2e.py::test_ptp_quickstart_advanced[Qwen3-30B-A3B-Qwen3/Qwen3-30B-A3B] test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] -test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index ce285faa79..30fc6c05b5 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -199,7 +199,7 @@ l0_a10: tests: - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] -l0_a10_nanobind: +l0_a10_pybind: - condition: ranges: system_gpu_count: @@ -211,6 +211,7 @@ l0_a10_nanobind: linux_distribution_name: ubuntu* terms: stage: pre_merge - backend: tensorrt tests: - unittest/bindings + - test_e2e.py::test_openai_chat_example[trt] + - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml index b8a846ccff..7a36bca755 100644 --- a/tests/integration/test_lists/test-db/l0_a100.yml +++ b/tests/integration/test_lists/test-db/l0_a100.yml @@ -99,3 +99,11 @@ l0_a100: - triton_server/test_triton.py::test_eagle[eagle] - triton_server/test_triton.py::test_llava_onevision[llava_onevision] - triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-ensemble] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-tensorrt_llm_bls] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-ensemble] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-ensemble] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-tensorrt_llm_bls] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-ensemble] + - triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls] diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index ae0d0bd041..66cf676f2f 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -79,7 +79,7 @@ l0_b200: - '*b100*' linux_distribution_name: ubuntu* terms: - stage: pre_merge + stage: post_merge backend: tensorrt tests: # ------------- TRT tests --------------- diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index a3179e38e9..36fcdce532 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -54,6 +54,8 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend - test_e2e.py::test_ptp_quickstart_advanced_bs1 diff --git a/tests/integration/test_lists/test-db/l0_gb200.yml b/tests/integration/test_lists/test-db/l0_gb200.yml index ac39fbdc88..7d1cc92fef 100644 --- a/tests/integration/test_lists/test-db/l0_gb200.yml +++ b/tests/integration/test_lists/test-db/l0_gb200.yml @@ -69,3 +69,4 @@ l0_gb200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] + - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90) diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml index 9c04ad7090..857319c44c 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml @@ -19,4 +19,3 @@ l0_gb200_multi_nodes: - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] TIMEOUT (90) - - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] TIMEOUT (90) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index bc84082317..0263c452b3 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -218,6 +218,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized + - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index db4f919855..cc970b452f 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -263,7 +263,6 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451) -accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437405,https://nvbugs/5437384) accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5440241) test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] SKIP (https://nvbugs/5444060,https://nvbugs/5444095) test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444060,https://nvbugs/5444095) @@ -298,6 +297,7 @@ triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nv triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482) triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320) +accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384) accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend SKIP (https://nvbugs/5448437) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend SKIP (https://nvbugs/5448437) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5445466) @@ -316,3 +316,12 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1 examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5464461) disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5448449) +full:H100/accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True] SKIP (https://nvbugs/5467815) +full:H100/accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=False] SKIP (https://nvbugs/5467815) +full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] SKIP (https://nvbugs/5467815) +full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] SKIP (https://nvbugs/5467815) +accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] SKIP (https://nvbugs/5470769) +full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051) +full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106) +full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) +test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index c81ca0ae1c..58d22302f2 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -8,9 +8,7 @@ from torch.export import export from tensorrt_llm._torch.auto_deploy.distributed import common as dist from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import ( - fuse_allreduce_residual_rmsnorm, -) +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm.llmapi.mpi_session import MpiPoolSession @@ -65,14 +63,21 @@ def _test_allreduce_fusion(port: int): original_outputs, residual_original = gm(x, residual) # Fuse ops - fuse_allreduce_residual_rmsnorm(gm) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_allreduce_residual_rmsnorm": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) # Run the fused graph - fused_outputs, residual_fused = gm(x, residual) + fused_outputs, residual_fused = gm_transformed(x, residual) # Check if fused node in the graph has_fused_node = False - for node in gm.graph.nodes: + for node in gm_transformed.graph.nodes: if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm): has_fused_node = True assert has_fused_node, "Fused node not found." @@ -86,8 +91,8 @@ def _test_allreduce_fusion(port: int): ) # check if we can still export the model as expected - export(gm, args=args) - torch_export_to_gm(gm, args=args) + export(gm_transformed, args=args) + torch_export_to_gm(gm_transformed, args=args) @pytest.mark.parametrize("device_count", get_device_counts()) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py index 4aa1a875c4..ed3b98f281 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py @@ -8,12 +8,13 @@ import torch import torch.nn as nn import torch.nn.functional as F from _dist_test_utils import get_device_counts -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from _torch_test_utils import fp8_compatible import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear -from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_collectives +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -61,11 +62,21 @@ def _run_job( is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes ) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_collectives": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + # now run the test - run_test( + run_test_transformed_gm( model, x, - transform=fuse_collectives, + gm_transformed, check_transformed_graph=check_transformed_graph, _get_expected_num_params=_get_expected_num_params, test_load_hook=False, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py index be2f9d52af..691aad78c5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py @@ -1,12 +1,11 @@ -from functools import partial - import pytest import torch -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from torch.export import Dim from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa -from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -52,15 +51,28 @@ def test_rmsnorm_fusion(eps, variant, op): return any(is_op(n, op) for n in gm.graph.nodes) model = TestModel(eps) - gm_transformed = run_test( + x = torch.randn(2, 1024, device="cuda", dtype=torch.float16) + dynamic_shapes = {0: Dim("batch_size", max=8)} + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_rmsnorm": { + "stage": "post_load_fusion", + "backend": variant, + }, + }, + )(None, gm) + + run_test_transformed_gm( model, - torch.randn(2, 1024, device="cuda", dtype=torch.float16), - partial(fuse_rmsnorm, backend=variant), + x, + gm_transformed, checker, lambda num_p_og: num_p_og, - dynamic_shapes={0: Dim("batch_size", max=8)}, + dynamic_shapes=dynamic_shapes, ) - print(gm_transformed.graph) + new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16) y_transformed = gm_transformed(new_input) y_model = model(new_input) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py index 82a5104503..b99862fdc1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py @@ -7,11 +7,12 @@ import pytest import torch import torch.nn as nn import torch.nn.functional as F -from _graph_test_helpers import count_buffers, run_test +from _graph_test_helpers import count_buffers, run_test_transformed_gm from _torch_test_utils import all_close, fp8_compatible, reset_parameters from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear -from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_gemms +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op torch.manual_seed(0) @@ -254,10 +255,20 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str): buffer_size_before = count_buffers(model) - gm_transformed = run_test( + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_gemms": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + run_test_transformed_gm( model, x, - fuse_gemms, + gm_transformed, lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes) == model.num_gemms_after_fusion, lambda num_p_og: num_p_og, # unchanged since fusing doesn't change param count diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index f2fd32ea3e..9266027e11 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -2,19 +2,35 @@ from typing import Optional import pytest import torch -from _graph_test_helpers import FakeFactory +from _graph_test_helpers import SequenceEmbeddingInfo from _model_test_utils import GQA from _torch_test_utils import all_close -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig, SequenceInfo -from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import FlashInferAttention -from tensorrt_llm._torch.auto_deploy.custom_ops.triton_attention import TritonAttention +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface -from tensorrt_llm._torch.auto_deploy.transform.interface import InferenceOptimizerConfig from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer -from tensorrt_llm._torch.auto_deploy.transformations.library import update_in_out_nodes -from tensorrt_llm._torch.auto_deploy.transformations.library.kvcache import insert_cached_attention + + +class DummyFactory(ModelFactory): + """Dummy factory to pass cache_config for testing.""" + + def __init__(self, model, cache_config): + self._model = model + self.cache_config = cache_config + + def build_model(self, device: str): + return self._model.to(device=device) + + def _build_model(self, device: str): + return + + def _load_checkpoint(self, model, device): + return + + def get_cache_config(self): + return self.cache_config # Class that uses SDPA directly instead of the regular attention mechanism @@ -68,42 +84,6 @@ class GQAWithSdpa(GQA): return self.o_proj(attn_output) -def _get_optimizer_config() -> InferenceOptimizerConfig: - return { - "build_model": { - "stage": "factory", - "device": "cuda", - "run_graph_cleanup": False, - "requires_clean_graph": False, - }, - "export_to_gm": { - "stage": "export", - "strict": False, - "clone_state_dict": True, - "run_graph_cleanup": False, - "requires_clean_graph": False, - }, - "cleanup_input_constraints": { - "stage": "post_export", - }, - } - - -class SequenceEmbeddingInfo(SequenceInfo): - hidden_size: int - dtype: torch.dtype - - def set_example_sequence(self) -> None: - super().set_example_sequence() - # set input ids to a 3D tensor (actually input embeddings) - self.input_ids = torch.rand( - *self.input_ids.shape, - self.hidden_size, - device=self.input_ids.device, - dtype=self.dtype, - ) - - # TODO (lucaslie): consider rewriting this test with a custom InferenceOptimizer config @pytest.mark.parametrize( "dtype", @@ -111,8 +91,8 @@ class SequenceEmbeddingInfo(SequenceInfo): ids=["float16", "float32"], ) @pytest.mark.parametrize( - "attn_descriptor", - [TritonAttention, FlashInferAttention], + "attn_backend", + ["triton", "flashinfer"], ids=["triton", "flashinfer"], ) @pytest.mark.parametrize( @@ -125,10 +105,10 @@ class SequenceEmbeddingInfo(SequenceInfo): ids=["regular", "gqa", "mqa"], ) @torch.inference_mode() -def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): +def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): """Test the SDPA transformation with KV cache.""" # flashinfer doesn't support float32 data type - if attn_descriptor == FlashInferAttention and dtype == torch.float32: + if attn_backend == "flashinfer" and dtype == torch.float32: pytest.skip("flashinfer doesn't support float32 data type") # Unpack the GQA configuration @@ -157,7 +137,6 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): hidden_size, num_key_value_heads, ).to(dtype=dtype, device="cuda") - factory = FakeFactory(model) # Create input tensor and position_ids x = torch.rand(batch_size, seq_len, hidden_size).to(device="cuda", dtype=dtype) @@ -166,21 +145,37 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): # Get the model's regular output y_model = model(x, position_ids) # b, s, d - # run modular inference optimizer up to post_export - optimizer = InferenceOptimizer(factory, _get_optimizer_config()) # type: ignore + # Apply the transformation + optimizer = InferenceOptimizer( + DummyFactory(model, CacheConfig()), + { + "build_model": { + "stage": "factory", + "device": "cuda", + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "export_to_gm": { + "stage": "export", + "strict": False, + "clone_state_dict": True, + "run_graph_cleanup": False, + "requires_clean_graph": False, + }, + "cleanup_input_constraints": { + "stage": "post_export", + }, + "update_in_out_nodes": { + "stage": "cache_init", + }, + "insert_cached_attention": { + "stage": "cache_init", + "attn_backend": attn_backend, + }, + }, + ) # type: ignore gm = optimizer(cm) - y_gm = gm(x, position_ids) - assert all_close(y_model, y_gm, atol=atol, rtol=rtol) - - # Set up cache configuration - cache_config = CacheConfig() - - # Get input node(s) - update_in_out_nodes(gm, cm) - - # Apply the transformation - insert_cached_attention(gm, cm, attn_descriptor=attn_descriptor, cache_config=cache_config) gm.to("cuda") cm.initialize_caches() diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index c2f5c32141..5ed816df8d 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -151,7 +151,8 @@ def test_autotuner_try_block(): class PartialCrashedRunner(TunableRunner): def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, + **kwargs) -> List[int]: return [-1, 0, 1] def forward(self, @@ -226,7 +227,7 @@ class GemmRunnerWithAttributes(TunableRunner): self.num_warps = num_warps def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, **kwargs) -> List[int]: return [-1, 0, 1] def forward(self, @@ -313,11 +314,9 @@ def test_multiple_dynamic_shapes_cache(): class GemmRunnerWithTacticConfigs(TunableRunner): valid_tactic_ids = [-1, 0, 1] - def get_valid_tactics( - self, - inputs: List[FakeTensor], - profile: OptimizationProfile, - ) -> List[Dict[str, int]]: + def get_valid_tactics(self, inputs: List[FakeTensor], + profile: OptimizationProfile, + **kwargs) -> List[Dict[str, int]]: # The simulated delay is not deterministic, so we need to return specific tactics here return [{ "block_size": block_size, diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index 58c854931e..3e727e654b 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -33,7 +33,9 @@ def extract_decode_logprobs(result: RequestOutput, def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler, max_batch_size, - mamba_ssm_cache_dtype=None): + mamba_ssm_cache_dtype=None, + enable_chunked_prefill=False, + max_num_tokens=None): """Create LLM with specific overlap scheduler setting""" model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K" return LLM( @@ -47,6 +49,8 @@ def create_nemotron_h_llm(use_cuda_graph, mamba_ssm_cache_dtype="auto" if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype), sampler_type="TRTLLMSampler", + enable_chunked_prefill=enable_chunked_prefill, + max_num_tokens=max_num_tokens, ) @@ -336,3 +340,62 @@ def test_nemotron_h_cuda_graph_overlap_scheduler(): msg=lambda x: f"Prompt {i}: with/without overlap scheduler (with CG) logprobs for all selected tokens {x}" ) + + +def test_nemotron_h_chunked_prefill(): + # Long prompts (~100 tokens) to make sure chunked prefill is enabled + # (At the time of development, tokens_per_block isn't configurable from the LLM API, + # and max_tokens (i.e. chunk size) needs to be a multiple of tokens_per_block) + prompts = [ + "Artificial Intelligence in Healthcare: Artificial intelligence (AI) is transforming healthcare by improving diagnostics, treatment plans, and patient care. AI algorithms can analyze medical images with high accuracy, assist in early disease detection, and personalize treatment plans based on patient data. Additionally, AI-powered chatbots and virtual assistants provide support to patients, enhancing accessibility and efficiency in healthcare services. As AI technology continues to advance, its integration into healthcare systems promises to deliver better outcomes and reduce costs. With continuous research and development, AI in healthcare is poised to", + "The Role of Cloud Computing: Cloud computing has revolutionized the way businesses operate by providing scalable, on-demand access to computing resources. This technology allows organizations to store and process data remotely, reducing the need for physical infrastructure and enabling greater flexibility. Cloud services facilitate collaboration, enhance data security, and support the deployment of innovative applications. As businesses increasingly adopt cloud solutions, they benefit from improved efficiency, cost savings, and the ability to rapidly adapt to changing market conditions. Companies leveraging cloud computing are better positioned to", + "Advancements in Renewable Energy: Renewable energy technologies, such as solar and wind power, are crucial for addressing climate change and reducing dependence on fossil fuels. Advances in energy storage, grid integration, and efficiency are making renewable energy sources more viable and cost-effective. Innovations in materials science and engineering are also driving the development of next-generation renewable technologies. As global efforts to combat climate change intensify, the continued advancement of renewable energy will play a pivotal role in achieving a sustainable future. Governments and industries are increasingly investing in", + "The Importance of Cybersecurity: In today's digital age, cybersecurity has become essential to protect sensitive information and maintain the integrity of systems. With the rise of cyber threats such as hacking, phishing, and ransomware, organizations must implement robust security measures to safeguard their data. Cybersecurity involves a combination of technologies, processes, and practices designed to defend against unauthorized access and attacks. By staying vigilant and updating security protocols, businesses can mitigate risks and ensure the safety of their digital assets. Proactive cybersecurity strategies are crucial in", + "The Impact of Artificial Intelligence on Education: Artificial intelligence is reshaping education by providing personalized learning experiences and automating administrative tasks. AI-driven educational tools can adapt to individual student needs, offering tailored feedback and resources to enhance learning outcomes. Additionally, AI can streamline administrative processes, allowing educators to focus more on teaching and student engagement. As AI continues to evolve, its role in education will expand, offering new opportunities for innovation and efficiency. The integration of AI in classrooms promises to revolutionize how students learn and how educators manage their", + ] + sampling_config = SamplingParams(max_tokens=10, + temperature=0.0, + return_context_logits=True, + return_generation_logits=True) + + with create_nemotron_h_llm(use_cuda_graph=False, + disable_overlap_scheduler=True, + max_batch_size=16) as llm: + outputs = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + + with create_nemotron_h_llm(use_cuda_graph=False, + disable_overlap_scheduler=True, + max_batch_size=16, + enable_chunked_prefill=True, + max_num_tokens=64) as llm: + chunked_prefill_outputs = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + + for i, (output, chunked_prefill_output) in enumerate( + zip(outputs, chunked_prefill_outputs)): + assert output.outputs[0].text == chunked_prefill_output.outputs[0].text + + # assert same prefill logprobs. Same atol as diff between mcore and initial impl + prefill_logprobs = extract_prefill_logprobs(output) + chunked_prefill_logprobs = extract_prefill_logprobs( + chunked_prefill_output) + torch.testing.assert_close( + prefill_logprobs, + chunked_prefill_logprobs, + atol=0.3, + rtol=0.05, + msg=lambda x: f"Prompt {i} prefill logprobs {x}") + + # Decode logprobs shouldn't be affected by chunked prefill - tolerance like batching tolerance + decode_logprobs = extract_decode_logprobs(output) + chunked_decode_logprobs = extract_decode_logprobs( + chunked_prefill_output) + torch.testing.assert_close( + decode_logprobs, + chunked_decode_logprobs, + atol=0.2, + rtol=0.05, + msg=lambda x: f"Prompt {i} decode logprobs {x}") diff --git a/tests/unittest/_torch/sampler/test_return_logits.py b/tests/unittest/_torch/sampler/test_return_logits.py index 0d6a5e28ca..a3af16c8bc 100644 --- a/tests/unittest/_torch/sampler/test_return_logits.py +++ b/tests/unittest/_torch/sampler/test_return_logits.py @@ -27,9 +27,6 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if sampler_type == "TorchSampler" and gather_context_logits: - pytest.skip("TorchSampler does not support gather_context_logits") - build_config = BuildConfig() build_config.gather_context_logits = gather_context_logits @@ -94,9 +91,6 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if sampler_type == "TorchSampler" and gather_context_logits: - pytest.skip("TorchSampler does not support gather_context_logits") - build_config = BuildConfig() build_config.gather_context_logits = gather_context_logits diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 641f37931c..a687e03645 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -1,6 +1,9 @@ +import json import os import sys +import tempfile import unittest +from pathlib import Path import pytest import torch @@ -121,5 +124,107 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, assert text_spec == text_ref +def test_deepseek_eagle3(): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + use_one_model = False + enable_chunked_prefill = False + + # Eagle3 one model works with overlap scheduler and block reuse. + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 150: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + eagle_config = { + 'architectures': ['LlamaForCausalLMEagle3'], + 'attention_bias': False, + 'attention_dropout': 0.0, + 'bos_token_id': 128000, + 'eos_token_id': [128001, 128008, 128009], + 'eagle_config': { + 'use_aux_hidden_state': False, + 'use_input_layernorm_in_first_layer': True, + 'use_last_layernorm': True, + 'use_mtp_layernorm': False + }, + 'head_dim': 128, + 'hidden_act': 'silu', + 'hidden_size': 2560, + 'initializer_range': 0.02, + 'intermediate_size': 16384, + 'max_position_embeddings': 4096, + 'mlp_bias': False, + 'model_type': 'llama', + 'num_attention_heads': 32, + 'num_eagle_features': 1, + 'num_hidden_layers': 1, + 'num_key_value_heads': 8, + 'pretraining_tp': 1, + 'rms_norm_eps': 1e-05, + 'rope_scaling': { + 'factor': 8.0, + 'high_freq_factor': 4.0, + 'low_freq_factor': 1.0, + 'original_max_position_embeddings': 8192, + 'rope_type': 'llama3' + }, + 'rope_theta': 500000.0, + 'tie_word_embeddings': False, + 'torch_dtype': 'bfloat16', + 'transformers_version': '4.52.4', + 'use_cache': True, + 'vocab_size': 129280, + 'draft_vocab_size': 129280 + } + with tempfile.TemporaryDirectory() as temp_dir: + eagle_model_dir = Path(temp_dir) + config_path = eagle_model_dir / "config.json" + with config_path.open("w") as f: + json.dump(eagle_config, f, indent=2) + target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only" + + # bs > 1 gives non-deterministic when doing IFB. There are slight chances + # that ref and spec does not match 100% + max_batch_size = 16 + max_draft_len = 3 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + max_num_tokens=4096, + max_seq_len=4096, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=eagle_model_dir, + # Llama 3 does not support one model eagle. + eagle3_one_model=use_one_model, + eagle3_layers_to_capture={29}, + load_format="dummy") + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + + if __name__ == "__main__": unittest.main() diff --git a/tests/unittest/_torch/thop/test_causal_conv1d_op.py b/tests/unittest/_torch/thop/test_causal_conv1d_op.py index c5e42e2618..54793854c9 100644 --- a/tests/unittest/_torch/thop/test_causal_conv1d_op.py +++ b/tests/unittest/_torch/thop/test_causal_conv1d_op.py @@ -26,11 +26,15 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory @pytest.mark.parametrize( - "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache", + "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache, use_initial_state", list( product([2048], [4], ['context', 'generation'], ['float16', 'float32', 'bfloat16'], [5], [16], [False, True], - [False, True], [False, True])) + + [False, True], [False, True], [False])) + + # test with initial state + list( + product([2048], [4], ['context'], ['bfloat16'], [5], [16], + [False, True], [False], [False, True], [True])) + # long sequence tests to cover the int overflow issue list( map( @@ -42,10 +46,11 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory "The long sequence test needs at least 33GB memory, skipping" )), product([5376], [4], ['context'], ['float16', 'bfloat16'], [2], - [131072], [False, True], [False, True], [False])))) + [131072], [False, True], [False, True], [False], [False])))) @pytest.mark.high_cuda_memory def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, - remove_padding, apply_silu, paged_cache): + remove_padding, apply_silu, paged_cache, + use_initial_state): device = "cuda" seq_len = max_seq_len if req_type == "context" else 1 mean = 0.0 @@ -68,7 +73,7 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, host_context_lengths = torch.ones( (batch_size, ), dtype=torch.int32) * seq_len - if req_type == "context": + if req_type == "context" and not use_initial_state: conv_state = torch.zeros([batch_size, dim, dconv - 1], dtype=torch_dtype, device=device) @@ -111,7 +116,8 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, conv_weight_input = conv_weight.squeeze(1).contiguous() if req_type == "context": - has_initial_state = None + has_initial_state = None if not use_initial_state else torch.ones( + batch_size, device=device, dtype=torch.bool) torch.ops.trtllm.causal_conv1d_fwd( x_in_out, diff --git a/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py b/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py index ea3c2c2c3c..e26fe00776 100644 --- a/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py +++ b/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py @@ -21,6 +21,8 @@ from einops import rearrange, repeat from utils.torch_ref import (selective_state_update_ref, ssd_chunk_scan_combined_ref) +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import \ + cu_seqlens_to_chunk_indices_offsets from tensorrt_llm._torch.modules.mamba.selective_state_update import \ selective_state_update from tensorrt_llm._torch.modules.mamba.ssd_combined import \ @@ -30,51 +32,58 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory @pytest.mark.parametrize( - "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache", + "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache, use_initial_states", # dim parametrization list( product([1024, 2048, 5120], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [16], [False], [True], [False])) + + ['bfloat16'], [3], [16], [False], [True], [False], [False])) + # headdim parametrization list( product([2048], [32, 64, 128, 256], [1], [128], ['context', 'generation'], ['bfloat16'], [3], [16], [False], - [True], [False])) + + [True], [False], [False])) + # ngroups parametrization list( product([2048], [64], [1, 4], [128], ['context', 'generation'], - ['bfloat16'], [3], [16], [False], [True], [False])) + + ['bfloat16'], [3], [16], [False], [True], [False], [False])) + # dstate parametrization list( product([2048], [64], [1], [64, 96, 128, 256], ['context', 'generation'], ['bfloat16'], [3], [16], [False], - [True], [False])) + + [True], [False], [False])) + # dtype parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], ['float16', 'bfloat16', 'float32'], [3], [16], [False], [True], - [False])) + + [False], [False])) + # batch_size parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [1, 2, 8, 16], [16], [False], [True], [False])) + + ['bfloat16'], [1, 2, 8, 16], [16], [False], [True], [False], + [False])) + # max_seq_len parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], ['bfloat16'], [3], [32, 64, 256, 2048, 16384], [False], [True], - [False])) + + [False], [False])) + # has_z parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [32], [True, False], [True], [False])) + + ['bfloat16'], [3], [32], [True, False], [True], [False], + [False])) + # remove_padding parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [32], [False], [True, False], [False])) + + ['bfloat16'], [3], [32], [False], [True, False], [False], + [False])) + # paged_cache parametrization (relevant for generation only) list( product([2048], [64], [1], [128], ['generation'], ['bfloat16'], [3], - [32], [False], [False], [True, False])) + + [32], [False], [False], [True, False], [False])) + + # use_initial_states parametrization (relevant for context only and remove_padding=True) + list( + product([2048], [64], [1], [128], ['context'], ['bfloat16'], [3], [32], + [False], [True], [False], [True, False])) + # long sequence test to cover the int overflow issue [ pytest.param( @@ -89,6 +98,7 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory False, False, False, + False, marks=pytest.mark.skipif( get_total_gpu_memory(0) < 68 * 1024**3, reason= @@ -97,7 +107,8 @@ from tensorrt_llm.llmapi.utils import get_total_gpu_memory def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, - remove_padding, paged_cache): + remove_padding, paged_cache, + use_initial_states): # configs device = "cuda" seq_len = max_seq_len if req_type == 'context' else 1 @@ -168,6 +179,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D = torch.randn(nheads, device=device) if has_z: z = torch.randn_like(x) + if use_initial_states: + initial_states = state.clone() if req_type == 'generation': # remove the seqlen dimension @@ -193,8 +206,13 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, C_ref = C.detach().clone() D_ref = D.detach().clone() z_ref = z.detach().clone() if has_z else None + initial_states_ref = state_ref.clone() if use_initial_states else None if req_type == "context": + if use_initial_states: + assert remove_padding + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + cu_seqlens, chunk_size) out, ssm_state = mamba_chunk_scan_combined( x, dt, @@ -205,6 +223,9 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D, z=z if has_z else None, dt_bias=dt_bias, + initial_states=initial_states if use_initial_states else None, + chunk_indices=chunk_indices if use_initial_states else None, + chunk_offsets=chunk_offsets if use_initial_states else None, seq_idx=seq_idx if remove_padding else None, cu_seqlens=cu_seqlens if remove_padding else None, dt_softplus=delta_softplus, @@ -273,7 +294,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref[:, start:end, ...] if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref[i:i + 1, ...] + if use_initial_states else None, + ) out_ref[0, start:end, ...] = part_out_ref.squeeze(0) state_ref[i, ...] = part_state_ref.squeeze(0) elif long_context: @@ -295,7 +319,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref[i:i + 1, ...] if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref[i:i + 1, ...] + if use_initial_states else None, + ) out_ref[i, ...] = part_out_ref.squeeze(0) state_ref[i, ...] = part_state_ref.squeeze(0) else: @@ -309,7 +336,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref + if use_initial_states else None, + ) elif req_type == 'generation': out_ref = selective_state_update_ref(state_ref, x_ref, @@ -330,3 +360,229 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, state_ref, rtol=1e-2, atol=atol[dtype]) + + +@pytest.mark.parametrize("mamba_chunk_size", [8, 256]) +@pytest.mark.parametrize("seqlens", [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), +]) +def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): + dim = 1024 + headdim = 64 + ngroups = 1 + dstate = 128 + + # test in high precision to distinguish between numeric instabilities and actual errors + dtype = 'float32' + + num_sequences = len(seqlens) + has_z = True + + device = "cuda" + nheads = dim // headdim + delta_softplus = True + mean = 0.0 + std_dev = 0.1 + + torch_dtype = str_dtype_to_torch(dtype) + + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device) + cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + seq_idx = torch.repeat_interleave(torch.arange(len(seqlens), + dtype=torch.int32, + device=device), + seqlens, + output_size=cu_seqlens[-1]).unsqueeze(0) + input_batch_size = 1 + input_seq_len = cu_seqlens[-1] + + # test data + torch.random.manual_seed(0) + x = torch.empty(input_batch_size, + input_seq_len, + nheads, + headdim, + device=device, + dtype=torch_dtype) + x.normal_(mean, std_dev) + dt = torch.randn(input_batch_size, + input_seq_len, + nheads, + device=device, + dtype=torch_dtype) + dt_bias = torch.rand(nheads, device=device) - 4.0 + A = -torch.rand(nheads, device=device) - 1.0 + B = torch.randn(input_batch_size, + input_seq_len, + ngroups, + dstate, + device=device, + dtype=torch_dtype) + C = torch.randn_like(B) + D = torch.randn(nheads, device=device) + + z = torch.randn_like(x) + + ## full seqlen computation + out_ref, state_ref = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + chunk_size=mamba_chunk_size, + D=D, + z=z if has_z else None, + dt_bias=dt_bias, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(chunked_seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), dtype=torch.int32, device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0) + chunked_input_seq_len = chunked_cu_seqlens[-1] + x_chunked = torch.zeros_like(x)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + z_chunked = torch.zeros_like(z)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # yapf: disable + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] + + x_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(x, i) + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) + z_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(z, i) + # yapf: enable + + partial_out, partial_state = mamba_chunk_scan_combined( + x_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size=mamba_chunk_size, + D=D, + z=z_chunked, + dt_bias=dt_bias, + seq_idx=chunked_seq_idx, + cu_seqlens=chunked_cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), + dtype=torch.int32, + device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # yapf: disable + remaining_x_chunked = torch.zeros_like(x)[:, :remaining_chunked_input_seq_len, ...] + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] + remaining_z_chunked = torch.zeros_like(z)[:, :remaining_chunked_input_seq_len, ...] + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] + + remaining_x_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(x, i) + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) + remaining_z_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(z, i) + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) + + assert concat_batch_f(x_chunked, remaining_x_chunked).equal(x) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + assert concat_batch_f(z_chunked, remaining_z_chunked).equal(z) + # yapf: enable + + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, mamba_chunk_size) + + out_chunked, state_chunked = mamba_chunk_scan_combined( + remaining_x_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size=mamba_chunk_size, + D=D, + z=remaining_z_chunked, + dt_bias=dt_bias, + initial_states=partial_state, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + seq_idx=remaining_chunked_seq_idx, + cu_seqlens=remaining_chunked_cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + out = concat_batch_f(partial_out, out_chunked) + + # kernel chunked is same as kernel overall + # tight tolerance to find subtle correctness issues + rtol = 1e-2 + atol = 2e-3 + for i in range(num_sequences): + out_seq = out[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + out_seq_ref = out_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + torch.testing.assert_close(out_seq[:, :chunked_seqlens[i], ...], + out_seq_ref[:, :chunked_seqlens[i], ...], + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} output part1 " + x) + torch.testing.assert_close(out_seq[:, chunked_seqlens[i]:, ...], + out_seq_ref[:, chunked_seqlens[i]:, ...], + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} output part2 " + x) + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close(state_seq, + state_seq_ref, + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} state " + x) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 6b78c46bd7..66d946d5c6 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -4,6 +4,7 @@ from contextlib import contextmanager, nullcontext import pytest from tensorrt_llm import LLM +from tensorrt_llm.executor import GenerationExecutorWorker from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer @@ -818,3 +819,40 @@ class TestLlmError: match="should not exceed max_num_tokens"): ids = [random.randint(10, 100) for _ in range(101)] llm.generate([ids]) + + +class FailingExecutorWorker(GenerationExecutorWorker): + """Mock worker that fails during initialization to test error handling.""" + + def __init__(self, *args, **kwargs): + # Simulate a constructor failure + raise RuntimeError( + "Mock GenerationExecutorWorker initialization failed") + + +FailingExecutor = type( + "FailingExecutor", (), { + "create": + classmethod( + lambda cls, *args, **kwargs: FailingExecutorWorker(*args, **kwargs)) + }) + + +def test_llm_with_proxy_error(): + """Test that LLM properly handles GenerationExecutorWorker constructor failures. + + This test mocks the GenerationExecutorWorker to fail during __init__ and + verifies that the LLM class properly catches and re-raises the error. + """ + from unittest.mock import patch + + # Test that the error is properly caught and re-raised by LLM + # We patch GenerationExecutor.create directly to return our failing worker + with patch('tensorrt_llm.executor.executor.GenerationExecutor.create', + side_effect=lambda *args, **kwargs: FailingExecutorWorker( + *args, **kwargs)): + with pytest.raises( + RuntimeError, + match="Mock GenerationExecutorWorker initialization failed"): + llm = LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config) diff --git a/tests/unittest/utils/torch_ref.py b/tests/unittest/utils/torch_ref.py index 6e666bed26..d8a6b258c5 100644 --- a/tests/unittest/utils/torch_ref.py +++ b/tests/unittest/utils/torch_ref.py @@ -480,7 +480,8 @@ def ssd_chunk_scan_combined_ref(x, D=None, z=None, dt_bias=None, - dt_softplus=False): + dt_softplus=False, + initial_states=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -492,6 +493,7 @@ def ssd_chunk_scan_combined_ref(x, D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) + initial_states: (batch, nheads, dstate, headdim) Return: out: (batch, seqlen, nheads, headdim) final_states: (batch, nheads, dstate, headdim) @@ -520,8 +522,16 @@ def ssd_chunk_scan_combined_ref(x, states = states.to(torch.float32) # 2. Pass the state to all the chunks by weighted cumsum. # state_passing_ref is much less numerically stable + # align initial_states shape with states shape + initial_states = rearrange( + initial_states, + "... n p -> ... p n") if initial_states is not None else None states, final_states = state_passing_ref( - rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]) + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(initial_states, "... p n-> ... (p n)") + if initial_states is not None else None, + ) states, final_states = [ rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] diff --git a/triton_backend/all_models/inflight_batcher_llm/preprocessing/1/model.py b/triton_backend/all_models/inflight_batcher_llm/preprocessing/1/model.py index 5a7d36ac1e..549f3f210d 100755 --- a/triton_backend/all_models/inflight_batcher_llm/preprocessing/1/model.py +++ b/triton_backend/all_models/inflight_batcher_llm/preprocessing/1/model.py @@ -29,13 +29,13 @@ import io import json import os from collections import defaultdict -from typing import List +from typing import Dict, List, Tuple import numpy as np import requests import triton_python_backend_utils as pb_utils from PIL import Image -from transformers import AutoProcessor, AutoTokenizer, T5Tokenizer +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, T5Tokenizer class TritonPythonModel: @@ -136,9 +136,9 @@ class TritonPythonModel: 'model_type'] assert self.model_type in [ - 'llava', 'blip2-opt', 'vila', 'mllama', 'llava_onevision', - 'qwen2_vl' - ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila, mllama, llava_onevision and qwen2_vl. Got {self.model_type}." + 'llava', 'blip2-opt', 'pixtral', 'vila', 'mllama', + 'llava_onevision', 'qwen2_vl' + ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, pixtral, vila, mllama, llava_onevision and qwen2_vl. Got {self.model_type}." assert self.model_type != 'llava_onevison' or self.max_num_images is None or self.max_num_images <= 1, f"LLaVA-OneVsion is not support multi image inference currently." @@ -151,10 +151,18 @@ class TritonPythonModel: llm_model_config["pretrained_config"]["vocab_size"]) self._setup_ptable_shape(llm_model_config) - if self.model_type in ['mllama', 'llava_onevision', 'qwen2_vl']: + if self.model_type in [ + 'mllama', 'llava_onevision', 'qwen2_vl', 'pixtral' + ]: + full_processor = AutoProcessor.from_pretrained( + tokenizer_dir, trust_remote_code=True) + self.hf_config = AutoConfig.from_pretrained(tokenizer_dir) self.vision_preprocessor = VisionPreProcessor( self.model_type, - AutoProcessor.from_pretrained(tokenizer_dir), model_config) + full_processor, + model_config, + self.hf_config, + ) # Parse model output configs and convert Triton types to numpy types output_names = [ @@ -285,7 +293,9 @@ class TritonPythonModel: request, 'VIDEO_BYTES') vision_processed_tensors = [] visual_tokens = [] - if self.is_multimodal and (img_urls or image_bytes or video_bytes): + # Pixtral supports text-only input + if self.is_multimodal and (img_urls or image_bytes or video_bytes + or self.model_type == 'pixtral'): assert self.vision_preprocessor != None, "Vision preprocessor for preparing images before encoding is None" processed_tensors = {} if self.model_type == 'mllama': @@ -317,6 +327,19 @@ class TritonPythonModel: qwen2vl_input_length_tensor = processed_tensors.get( "REQUEST_INPUT_LEN") processed_tensors.pop("REQUEST_INPUT_LEN") + elif self.model_type == 'pixtral': + image_sizes = pb_utils.get_input_tensor_by_name( + request, 'IMAGE_SIZES') + processed_tensors, visual_tokens = self.vision_preprocessor.pixtral_process( + queries=query.astype(str).tolist(), + img_urls=img_urls, + image_bytes=image_bytes, + image_sizes=image_sizes, + ) + pixtral_input_id_tensor = processed_tensors.pop("INPUT_IDS") + request_input_len = np.array( + [[len(input_ids_for_batch)] + for input_ids_for_batch in pixtral_input_id_tensor]) else: raise ValueError( "Unsupported model type for IMAGE_BYTES or IMAGE_URL inputs" @@ -330,8 +353,9 @@ class TritonPythonModel: # Preprocessing input data. # For the LLaVA_OneVision model, num_multimodal_features is not a fixed value - input_id, request_input_len = self._create_request( - query, visual_tokens) + if self.model_type != 'pixtral': + input_id, request_input_len = self._create_request( + query, visual_tokens) if decoder_query is not None: decoder_input_id, request_decoder_input_len = self._create_request( decoder_query) @@ -362,6 +386,13 @@ class TritonPythonModel: 'INPUT_ID', qwen2vl_input_id_tensor) request_input_len_tensor = pb_utils.Tensor.from_dlpack( 'REQUEST_INPUT_LEN', qwen2vl_input_length_tensor) + elif self.model_type == 'pixtral': + input_id_tensor = pb_utils.Tensor( + 'INPUT_ID', + pixtral_input_id_tensor.numpy().astype(self.input_id_dtype)) + request_input_len_tensor = pb_utils.Tensor( + 'REQUEST_INPUT_LEN', + request_input_len.astype(self.request_input_len_dtype)) else: input_id_tensor = pb_utils.Tensor( 'INPUT_ID', input_id.astype(self.input_id_dtype)) @@ -719,7 +750,10 @@ class VisionPreProcessor: def __init__(self, vision_model_type, vision_model_processor, - preprocessor_model_config={}): + preprocessor_model_config=None, + hf_config=None): + preprocessor_model_config = preprocessor_model_config or {} + # import libraries that are only relevant for multimodal models import torch from torch.utils.dlpack import from_dlpack @@ -767,6 +801,12 @@ class VisionPreProcessor: self.vision_model_processor = vision_model_processor self.vision_model_type = vision_model_type + if vision_model_type == 'pixtral': + assert hf_config is not None, "Pixtral model requires hf_config to be set" + self.vocab_size = hf_config.text_config.vocab_size + self.image_size = hf_config.vision_config.image_size + self.image_token_index = hf_config.image_token_index + def load_images_from_urls(self, img_urls): images = [] for img_url in img_urls: @@ -777,10 +817,11 @@ class VisionPreProcessor: image_data = base64.b64decode(image_base64) # Create a BytesIO object from the decoded data image_buffer = io.BytesIO(image_data) - images.append(Image.open(image_buffer)) + images.append(Image.open(image_buffer).convert("RGB")) else: - images.append(Image.open( - requests.get(img_url, stream=True).raw)) + images.append( + Image.open(requests.get(img_url, + stream=True).raw).convert("RGB")) return images def mllama_process(self, queries, img_urls=None, image_bytes=None): @@ -879,6 +920,9 @@ class VisionPreProcessor: mode='constant') for image in preprocessor_outputs['PIXEL_VALUES'] ] + # Add a dimension image_sizes to match the dimensions defined in config.pbtxt + for elem in preprocessor_outputs['IMAGE_SIZES']: + elem.unsqueeze_(1) for key, tensor_list in preprocessor_outputs.items(): val = self.convert_tensor_list_to_tensor(tensor_list) if key in self.output_str_dtypes: @@ -1001,3 +1045,130 @@ class VisionPreProcessor: val, self.output_str_dtypes[key]) vision_processed_tensors[key] = val return vision_processed_tensors + + def pixtral_process(self, + queries, + img_urls=None, + image_bytes=None, + image_sizes=None + ) -> Tuple[Dict[str, "torch.Tensor"], List[int]]: + import torch + vision_processed_tensors = {} + if img_urls is not None: + assert image_sizes is None, "IMAGE_SIZES should not be supplied together with IMAGE_URL" + # download and read images + images = [ + self.load_images_from_urls(urls) + for urls in img_urls.as_numpy() + ] + images = [[np.array(img) for img in batch] for batch in images] + + # pad to the max_h, max_w dimensions to create one tensor for all images + shapes = [img.shape for batch in images for img in batch] + assert all( + len(s) == 3 + for s in shapes), "All input images must have three dimensions" + assert all( + s[-1] == shapes[0][-1] for s in shapes + ), "All input images must have the same number of channels" + max_h, max_w = max(s[0] for s in shapes), max(s[1] for s in shapes) + for batch_idx in range(len(images)): + for image_idx in range(len(images[batch_idx])): + images[batch_idx][image_idx] = np.pad( + images[batch_idx][image_idx], + ((0, max_h - images[batch_idx][image_idx].shape[0]), + (0, max_w - images[batch_idx][image_idx].shape[1]), + (0, 0)), + mode='constant', + ) + images = np.array(images) + elif image_bytes is not None: + images = self.load_images_tensor(image_bytes) + else: + images = np.empty((len(queries), 0, 0, 0, 0), dtype=np.uint8) + + batch_size = len(images) + assert len( + queries + ) == batch_size, f"Image must have the same batch size as Query." + + if image_sizes is not None: + image_sizes = self.load_images_tensor(image_sizes) + else: + s = images.shape + image_sizes = np.array([[[s[2], s[3]]] * s[1]] * s[0]) + + preprocessor_outputs = {} + possible_output_names = ['PIXEL_VALUES', 'IMAGE_SIZES', 'INPUT_IDS'] + visual_tokens = [] + for batch_id in range(batch_size): + # Preprocess images and query + query = queries[batch_id] + if not isinstance(query, (str, bytes)): + query = query[0] + if isinstance(query, bytes): + query = query.decode("utf-8") + if "[IMG]" not in query: + query = "[IMG]" * len(images[batch_id]) + query + assert query.count("[IMG]") == len( + images[batch_id] + ), "Number of [IMG] tags must match number of images" + + if not query.startswith("[INST]"): + query = "[INST]" + query + if not query.endswith("[/INST]"): + query = query + "[/INST]" + + sizes = image_sizes[batch_id] + curr_images = [ + img[:sizes[idx][0], :sizes[idx][1], :] + for idx, img in enumerate(images[batch_id]) + ] + if not curr_images: + curr_images = None + + processed_vision_data = self.vision_model_processor( + images=curr_images, text=query, return_tensors="pt") + visual_tokens.append(processed_vision_data['input_ids'].shape[1]) + if "pixel_values" in processed_vision_data: + # Pad to self.image_size x self.image_size + processed_vision_data['pixel_values'] = torch.nn.functional.pad( + processed_vision_data['pixel_values'], ( + 0, + self.image_size - + processed_vision_data['pixel_values'].shape[-1], + 0, + self.image_size - + processed_vision_data['pixel_values'].shape[-2], + ), + mode='constant') + # Create vision output tensors + for key in possible_output_names: + val = processed_vision_data.get(key.lower()) + if val is not None: + if key not in preprocessor_outputs: + preprocessor_outputs[key] = [] + if key != 'INPUT_IDS': + val.unsqueeze_(0) # unsqueeze to add batch dimension + preprocessor_outputs[key].append(val) + + for key, tensor_list in preprocessor_outputs.items(): + val = self.convert_tensor_list_to_tensor(tensor_list) + if key in self.output_str_dtypes: + val = self.convert_tensor_to_str_dtype( + val, self.output_str_dtypes[key]) + vision_processed_tensors[key] = val + + # Replace all image tokens with a unique token_id > vocab_size. + # This shall be used to lookup the prompt table. + for batch_id in range(batch_size): + # Note: We reset replacer to vocab_size for each sample. This is as opposed to doing `replacer = vocab_size + img_idx * tokens_per_task`. + # That part of the look-up manipulation is done by the `task_ids` input to PromptEmbedding forward. + replacer = self.vocab_size + input_ids = vision_processed_tensors['INPUT_IDS'][batch_id] + for token_idx in range(len(input_ids)): + if input_ids[token_idx] == self.image_token_index: + input_ids[token_idx] = replacer + replacer += 1 + + return vision_processed_tensors, visual_tokens diff --git a/triton_backend/all_models/inflight_batcher_llm/preprocessing/config.pbtxt b/triton_backend/all_models/inflight_batcher_llm/preprocessing/config.pbtxt index b21585e4bd..ed819b7b60 100755 --- a/triton_backend/all_models/inflight_batcher_llm/preprocessing/config.pbtxt +++ b/triton_backend/all_models/inflight_batcher_llm/preprocessing/config.pbtxt @@ -55,7 +55,14 @@ input [ { name: "IMAGE_URL" data_type: TYPE_STRING - dims: [ 1 ] + dims: [ -1 ] + optional: true + }, + # Required for pixtral + { + name: "IMAGE_SIZES" + data_type: TYPE_INT64 + dims: [ -1, 2 ] optional: true }, { @@ -188,11 +195,11 @@ output [ data_type: TYPE_INT64 dims: [ -1, -1, -1 ] }, - # Required for image postprocessing in the llava_onevision model + # Required for image postprocessing in the llava_onevision and pixtral models { name: "IMAGE_SIZES" data_type: TYPE_INT64 - dims: [ 2 ] + dims: [ -1, 2 ] }, # Indicates if the input is video in the llava_onevision model { diff --git a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py index 4f8863465b..ab165323c1 100755 --- a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py +++ b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py @@ -280,7 +280,9 @@ def get_prompt_tuning_config_from_request(request, kwargs = {} prompt_embedding_table = get_input_tensor_by_name(request, 'prompt_embedding_table', - batch_size, batch_index) + batch_size, + batch_index, + force_on_torch=True) prompt_table_extra_ids = get_input_tensor_by_name(request, 'prompt_table_extra_ids', batch_size, batch_index) diff --git a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt index 4f06581c04..f5f6cb41a4 100644 --- a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt +++ b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt @@ -319,7 +319,7 @@ input [ }, { name: "prompt_embedding_table" - data_type: TYPE_FP16 + data_type: ${prompt_embedding_table_data_type} dims: [ -1, -1 ] optional: true allow_ragged_batch: true diff --git a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py index 566e62cb7c..e6ed45d185 100644 --- a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py +++ b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py @@ -103,6 +103,7 @@ class Request: request_id: Optional[str] = None mrope_rotary_cos_sin: Optional[np.ndarray] = None mrope_position_deltas: Optional[np.ndarray] = None + image_sizes_input: Optional[np.ndarray] = None def validate(self): _validate_non_empty(self.text_input, "text_input is required") diff --git a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py index 00bd315b13..2c9c5b8055 100644 --- a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py +++ b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py @@ -165,7 +165,12 @@ class TritonDecoder(Decoder): continue triton_name = tensor.name() if tensor.is_cpu(): - value = tensor.as_numpy() + try: + value = tensor.as_numpy() + except pb_utils.TritonModelException as e: + # Use to_dlpack()/from_dlpack() if as_numpy() fails, + # e.g. in case of BF16 tensors + value = from_dlpack(tensor.to_dlpack()) else: # If the tensor is in GPU memory make it torch.Tensor type value = from_dlpack(tensor.to_dlpack()) @@ -247,6 +252,7 @@ class TritonDecoder(Decoder): "text_input": "QUERY", "image_bytes_input": "IMAGE_BYTES", "image_url_input": "IMAGE_URL", + "image_sizes_input": "IMAGE_SIZES", "video_bytes_input": "VIDEO_BYTES", "decoder_text_input": "DECODER_QUERY", "max_tokens": "REQUEST_OUTPUT_LEN", diff --git a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt index 7f38bf903a..7ad5ccf9f1 100644 --- a/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt +++ b/triton_backend/all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt @@ -62,6 +62,13 @@ input [ dims: [ 1 ] optional: true }, + # An arbitrary number of images for pixtral + { + name: "image_sizes_input" + data_type: TYPE_INT64 + dims: [ -1, 2 ] + optional: true + }, { name: "video_bytes_input" data_type: TYPE_UINT8 @@ -199,7 +206,7 @@ input [ }, { name: "prompt_embedding_table" - data_type: TYPE_FP16 + data_type: ${prompt_embedding_table_data_type} dims: [ -1, -1 ] optional: true }, diff --git a/triton_backend/all_models/multimodal/ensemble/config.pbtxt b/triton_backend/all_models/multimodal/ensemble/config.pbtxt index d3affefabf..777118eefe 100755 --- a/triton_backend/all_models/multimodal/ensemble/config.pbtxt +++ b/triton_backend/all_models/multimodal/ensemble/config.pbtxt @@ -54,9 +54,16 @@ input [ { name: "image_url_input" data_type: TYPE_STRING - dims: [ 1 ] + dims: [ -1 ] optional: true }, + # An arbitrary number of images for pixtral + { + name: "image_sizes_input" + data_type: TYPE_INT64 + dims: [ -1, 2 ] + optional: true + }, { name: "video_bytes_input" data_type: TYPE_UINT8 @@ -253,6 +260,10 @@ ensemble_scheduling { key: "IMAGE_URL" value: "image_url_input" } + input_map { + key: "IMAGE_SIZES" + value: "image_sizes_input" + } input_map { key: "VIDEO_BYTES" value: "video_bytes_input" diff --git a/triton_backend/all_models/multimodal/multimodal_encoders/1/model.py b/triton_backend/all_models/multimodal/multimodal_encoders/1/model.py index acf601ac80..adb968d461 100755 --- a/triton_backend/all_models/multimodal/multimodal_encoders/1/model.py +++ b/triton_backend/all_models/multimodal/multimodal_encoders/1/model.py @@ -112,6 +112,8 @@ class TritonPythonModel: self.image_session = Session.from_serialized_engine(engine_buffer) self.vision_dtype_str = visual_config['builder_config']['precision'] + self.vision_max_batch_size = visual_config['builder_config'][ + 'max_batch_size'] features_output_name = "OUT_PROMPT_EMBEDDING_TABLE" if self.model_type == "mllama": features_output_name = "ENCODER_INPUT_FEATURES" @@ -162,7 +164,21 @@ class TritonPythonModel: self.vocab_size = hf_config.vocab_size self.qwen2vl_utils = Qwen2VLUtils(hf_config) - def get_requests(self, request: List) -> Dict[str, torch.Tensor]: + if self.model_type == 'pixtral': + from transformers import AutoConfig + hf_model_path = model_config['parameters'].get( + 'hf_model_path', None) + assert hf_model_path is not None and hf_model_path[ + 'string_value'] != "${hf_model_path}", "Need to provide hf_model_path for the Pixtral model" + hf_config = AutoConfig.from_pretrained( + hf_model_path['string_value']) + self.image_size = hf_config.vision_config.image_size + self.patch_size = hf_config.vision_config.patch_size + self.vocab_size = hf_config.text_config.vocab_size + self.spatial_merge_size = hf_config.spatial_merge_size + self.relevant_patch_size = self.patch_size * self.spatial_merge_size + + def get_requests(self, request) -> Dict[str, torch.Tensor]: """ Processes the incoming request to extract and organize input tensors for different model types. @@ -193,8 +209,10 @@ class TritonPythonModel: img_tensor = (pb_utils.get_input_tensor_by_name(request, 'pixel_values') or pb_utils.get_input_tensor_by_name(request, 'IMAGE')) - # mllama supports img_tensor is None case - assert img_tensor != None or self.model_type == 'mllama', "There is no preprocessed image tensor to encode" + # mllama and pixtral support img_tensor is None case + assert img_tensor != None or self.model_type in [ + 'mllama', 'pixtral' + ], "There is no preprocessed image tensor to encode" if img_tensor is not None: img_tensor = from_dlpack(img_tensor.to_dlpack()) @@ -242,6 +260,9 @@ class TritonPythonModel: image_sizes = from_dlpack( pb_utils.get_input_tensor_by_name( request, 'image_sizes').to_dlpack()) + # Remove dimension 1, which was added to match the dimensions defined in config.pbtxt + assert image_sizes.shape[1] == 1 + image_sizes.squeeze_(1) from transformers.models.llava_onevision.modeling_llava_onevision import \ image_size_to_num_patches image_num_patches = [ @@ -276,6 +297,33 @@ class TritonPythonModel: input_tensors['attention_mask_llm'].append(attention_mask) input_tensors['image_grid_thw'].append(image_grid_thw) + elif self.model_type == 'pixtral': + if img_tensor is None: + input_tensors['pixel_values'].append(None) + else: + assert batch_size == 1, "Only support batch size 1 for Pixtral, because each batch can contain a different number of images" + d_min = torch.finfo(self.vision_output_dtype).min + total_images = img_tensor.shape[0] * img_tensor.shape[1] + num_patches = self.image_size // self.patch_size + input_tensors['input'].append( + img_tensor.view(-1, img_tensor.shape[2], + img_tensor.shape[3], img_tensor.shape[4])) + attention_mask_shape = (total_images, num_patches, num_patches) + attention_mask = torch.full(attention_mask_shape, + fill_value=d_min, + dtype=self.vision_output_dtype, + device="cuda") + image_sizes = from_dlpack( + pb_utils.get_input_tensor_by_name( + request, + 'image_sizes').to_dlpack()).reshape(total_images, 2) + for image_idx in range(total_images): + image_h, image_w = image_sizes[image_idx][0], image_sizes[ + image_idx][1] + attention_mask[image_idx, :image_h // + self.patch_size, :image_w // + self.patch_size] = 0 + input_tensors['attention_mask'].append(attention_mask) else: input_tensors['input'].append( img_tensor.view(-1, img_tensor.shape[2], img_tensor.shape[3], @@ -408,7 +456,7 @@ class TritonPythonModel: f"encoder_output_lengths: {encoder_output_lengths}") # True when the request does not have image input - output_tensors = [ + response_tensors = [ pb_utils.Tensor.from_dlpack( 'ENCODER_INPUT_FEATURES', to_dlpack(encoder_input_features)), @@ -417,16 +465,16 @@ class TritonPythonModel: to_dlpack(encoder_output_lengths)) ] if cross_attention_mask is not None: - output_tensors.append( + response_tensors.append( pb_utils.Tensor.from_dlpack( 'CROSS_ATTENTION_MASK', to_dlpack(cross_attention_mask))) - output_tensors.append( + response_tensors.append( pb_utils.Tensor.from_dlpack( 'SKIP_CROSS_ATTN_BLOCKS', to_dlpack(skip_cross_attn_blocks))) inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors) + output_tensors=response_tensors) responses.append(inference_response) elif self.model_type == 'llava_onevision': for req_idx, embeddings in enumerate( @@ -443,6 +491,9 @@ class TritonPythonModel: image_sizes = from_dlpack( pb_utils.get_input_tensor_by_name( request, 'image_sizes').to_dlpack()) + # Remove dimension 1, which was added to match the dimensions defined in config.pbtxt + assert image_sizes.shape[1] == 1 + image_sizes.squeeze_(1) from transformers.models.llava_onevision.modeling_llava_onevision import \ image_size_to_num_patches image_num_patches = [ @@ -458,10 +509,10 @@ class TritonPythonModel: embeddings, image_sizes, image_num_patches) prompt_embedding_table_tensor = pb_utils.Tensor.from_dlpack( 'OUT_PROMPT_EMBEDDING_TABLE', to_dlpack(prompt_table)) - output_tensors = [prompt_embedding_table_tensor] + response_tensors = [prompt_embedding_table_tensor] inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors) + output_tensors=response_tensors) responses.append(inference_response) elif self.model_type == 'qwen2_vl': image_grid_thw = other_vision_input_tensors.get('image_grid_thw') @@ -493,12 +544,92 @@ class TritonPythonModel: 'MROPE_ROTARY_COS_SIN', to_dlpack(mrope_rotary_cos_sin)) mrope_position_deltas_tensor = pb_utils.Tensor.from_dlpack( 'MROPE_POSITION_DELTAS', to_dlpack(mrope_position_deltas)) - output_tensors = [ + response_tensors = [ prompt_embedding_table_tensor, mrope_rotary_cos_sin_tensor, mrope_position_deltas_tensor ] inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors) + output_tensors=response_tensors) + responses.append(inference_response) + elif self.model_type == 'pixtral': + assert len(num_images) == len(batch_sizes) == len( + is_skip_encoders) == len(requests) + images_per_batch = [i * b for i, b in zip(num_images, batch_sizes)] + split_along = np.cumsum(images_per_batch).tolist() + if output_tensor is not None: + splitted_output_tensor = torch.tensor_split(output_tensor, + split_along, + dim=0) + visual_embed_dim = output_tensor.shape[-1] + output_img_size = self.image_size // self.relevant_patch_size + + for req_idx, request in enumerate(requests): + if is_skip_encoders[req_idx]: + responses.append( + pb_utils.InferenceResponse(output_tensors=[])) + continue + + response_tensors = [] + assert splitted_output_tensor[req_idx].ndim == 3 + current_output_tensor = splitted_output_tensor[req_idx].reshape( + batch_sizes[req_idx], num_images[req_idx], + splitted_output_tensor[req_idx].shape[-2], + splitted_output_tensor[req_idx].shape[-1]) + image_sizes = from_dlpack( + pb_utils.get_input_tensor_by_name( + request, 'image_sizes').to_dlpack()) + complete_visual_features = [] + vocab_size = [] + for batch_idx in range(batch_sizes[req_idx]): + batch_visual_features = [] + for image_idx in range(num_images[req_idx]): + image_h = image_sizes[batch_idx][image_idx][0] + image_w = image_sizes[batch_idx][image_idx][1] + h_patches = image_h // self.relevant_patch_size + w_patches = image_w // self.relevant_patch_size + relevant_visual_features = torch.zeros( + 1, h_patches * w_patches, visual_embed_dim) + visual_features = current_output_tensor[batch_idx][ + image_idx].reshape(output_img_size, output_img_size, + visual_embed_dim) + flattened_features = visual_features[:h_patches, : + w_patches, :].flatten( + 0, 1) + relevant_visual_features[ + 0, :h_patches * w_patches, :] = flattened_features + batch_visual_features.append(relevant_visual_features) + batch_visual_features = torch.cat(batch_visual_features, + dim=1) + vocab_size.append(batch_visual_features.shape[1]) + complete_visual_features.append(batch_visual_features) + + # Pad elements of complete_visual_features to have the same shape[1], + # to allow concatenation over batch dimension + max_vocab_size = max(vocab_size) + for batch_idx in range(batch_sizes[req_idx]): + complete_visual_features[ + batch_idx] = torch.nn.functional.pad( + complete_visual_features[batch_idx], + (0, 0, 0, max_vocab_size - + complete_visual_features[batch_idx].shape[1]), + mode='constant') + complete_visual_features = torch.cat(complete_visual_features, + dim=0) + + prompt_embedding_table_tensor = pb_utils.Tensor.from_dlpack( + 'OUT_PROMPT_EMBEDDING_TABLE', + to_dlpack( + complete_visual_features.type( + self.vision_output_dtype))) + prompt_vocab_size_tensor = pb_utils.Tensor( + 'OUT_PROMPT_VOCAB_SIZE', + np.array(vocab_size, + dtype=np.int32).reshape(batch_sizes[req_idx], 1)) + + response_tensors.extend( + [prompt_embedding_table_tensor, prompt_vocab_size_tensor]) + inference_response = pb_utils.InferenceResponse( + output_tensors=response_tensors) responses.append(inference_response) else: for req_idx, embeddings in enumerate( @@ -530,17 +661,67 @@ class TritonPythonModel: prompt_vocab_size_tensor = pb_utils.Tensor( 'OUT_PROMPT_VOCAB_SIZE', prompt_vocab_size.astype(np.int32)) - output_tensors = [ + response_tensors = [ prompt_embedding_table_tensor, prompt_vocab_size_tensor ] inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors) + output_tensors=response_tensors) responses.append(inference_response) # You should return a list of pb_utils.InferenceResponse. Length # of this list must match the length of `requests` list. return responses + def run_vision_encoder(self, vit_input: Dict[str, + torch.Tensor]) -> torch.Tensor: + batch_size = [v.shape[0] for v in vit_input.values()] + assert all( + b == batch_size[0] + for b in batch_size), "Batch sizes of encoder inputs must match" + batch_size = batch_size[0] + + embeddings = [] + for start_idx in range(0, batch_size, self.vision_max_batch_size): + end_idx = min(start_idx + self.vision_max_batch_size, batch_size) + logger.debug( + f"Running encoder (max_batch_size={self.vision_max_batch_size}) " + + f"with batch indices {start_idx}:{end_idx} of {batch_size}.") + + # Slice the input tensors along the batch dimension + vit_input_batch = { + k: v[start_idx:end_idx] + for k, v in vit_input.items() + } + + # Set up output tensors + vit_input_info = [ + TensorInfo(key, torch_dtype_to_trt(val.dtype), val.shape) + for key, val in vit_input_batch.items() + ] + vit_output_info = self.image_session.infer_shapes(vit_input_info) + + vit_output_batch = { + t.name: + torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in vit_output_info + } + + # Run the vision encoder + with torch.cuda.stream(self.vision_stream): + ok = self.image_session.run(vit_input_batch, vit_output_batch, + self.vision_stream.cuda_stream) + assert ok, "Runtime execution failed for vision encoder session" + embeddings.append(vit_output_batch['encoder_output'].to( + self.vision_output_dtype)) + + with torch.cuda.stream(self.vision_stream): + embeddings = torch.cat(embeddings, dim=0) + + self.vision_stream.synchronize() + return embeddings + def execute(self, requests: List): """`execute` must be implemented in every Python model. `execute` function receives a list of pb_utils.InferenceRequest as the only @@ -664,28 +845,8 @@ class TritonPythonModel: vit_input['attention_mask'] = attention_mask_vit.to( str_dtype_to_torch(self.vision_dtype_str)).to('cuda') - # Set up output tensors - vit_input_info = [ - TensorInfo(key, torch_dtype_to_trt(val.dtype), val.shape) - for key, val in vit_input.items() - ] - vit_output_info = self.image_session.infer_shapes( - vit_input_info) - vit_output = { - t.name: - torch.empty(tuple(t.shape), - dtype=trt_dtype_to_torch(t.dtype), - device='cuda') - for t in vit_output_info - } - # Run the vision encoder - with torch.cuda.stream(self.vision_stream): - ok = self.image_session.run(vit_input, vit_output, - self.vision_stream.cuda_stream) - assert ok, "Runtime execution failed for vision encoder session" - embeddings = vit_output['encoder_output'].to( - self.vision_output_dtype) - self.vision_stream.synchronize() + embeddings = self.run_vision_encoder(vit_input) + # Post process output and save in responses responses.extend( self.postprocess_output_tensors(embeddings, diff --git a/triton_backend/all_models/multimodal/multimodal_encoders/config.pbtxt b/triton_backend/all_models/multimodal/multimodal_encoders/config.pbtxt index 715c491501..c2a79e01e7 100755 --- a/triton_backend/all_models/multimodal/multimodal_encoders/config.pbtxt +++ b/triton_backend/all_models/multimodal/multimodal_encoders/config.pbtxt @@ -72,13 +72,14 @@ input [ dims: [ 1 ] optional: true }, - # input tensors for llava_onevision + # Required for llava_onevision and pixtral { name: "image_sizes" data_type: TYPE_INT64 - dims: [ 2 ] + dims: [ -1, 2 ] optional: true }, + # Required for llava_onevision { name: "is_video_input" data_type: TYPE_BOOL @@ -114,7 +115,7 @@ input [ output [ { name: "OUT_PROMPT_EMBEDDING_TABLE" - data_type: TYPE_FP16 + data_type: ${prompt_embedding_table_data_type} dims: [ -1, -1 ] }, { diff --git a/triton_backend/all_models/multimodal/requirements-mistral3.1.txt b/triton_backend/all_models/multimodal/requirements-mistral3.1.txt new file mode 100644 index 0000000000..954e44483a --- /dev/null +++ b/triton_backend/all_models/multimodal/requirements-mistral3.1.txt @@ -0,0 +1 @@ +transformers>=4.50.0 diff --git a/triton_backend/all_models/tests/test_llmapi_python_backend.py b/triton_backend/all_models/tests/test_llmapi_python_backend.py index 6ef7cd9946..6ab4120aa4 100644 --- a/triton_backend/all_models/tests/test_llmapi_python_backend.py +++ b/triton_backend/all_models/tests/test_llmapi_python_backend.py @@ -64,6 +64,12 @@ class MockTritonTensor: else: return False + def to_dlpack(self): + if self.is_cpu(): + return self._tensor.__dlpack__() + else: + return self._tensor.to_dlpack() + @dataclass class MockTritonError: diff --git a/triton_backend/all_models/tests/test_python_backend.py b/triton_backend/all_models/tests/test_python_backend.py index b993af957f..8e17f2b09f 100644 --- a/triton_backend/all_models/tests/test_python_backend.py +++ b/triton_backend/all_models/tests/test_python_backend.py @@ -63,6 +63,12 @@ class MockTritonTensor: else: return False + def to_dlpack(self): + if self.is_cpu(): + return self._tensor.__dlpack__() + else: + return self._tensor.to_dlpack() + @dataclass class MockTritonError: diff --git a/triton_backend/all_models/tests/test_triton_decoder.py b/triton_backend/all_models/tests/test_triton_decoder.py index 7ebcb28e99..90f71107b7 100644 --- a/triton_backend/all_models/tests/test_triton_decoder.py +++ b/triton_backend/all_models/tests/test_triton_decoder.py @@ -64,6 +64,12 @@ class MockTritonTensor: else: return False + def to_dlpack(self): + if self.is_cpu(): + return self._tensor.__dlpack__() + else: + return self._tensor.to_dlpack() + @dataclass class MockTritonResponse: diff --git a/triton_backend/ci/L0_backend_trtllm/test.sh b/triton_backend/ci/L0_backend_trtllm/test.sh index 83967d1c58..272a208b53 100644 --- a/triton_backend/ci/L0_backend_trtllm/test.sh +++ b/triton_backend/ci/L0_backend_trtllm/test.sh @@ -197,6 +197,7 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do replace_config_tags '${max_queue_delay_microseconds}' "50000" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${triton_backend}' "tensorrtllm" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${encoder_input_features_data_type}' "TYPE_FP16" "${MODEL_DIR}/tensorrt_llm/config.pbtxt" + replace_config_tags '${prompt_embedding_table_data_type}' 'TYPE_FP16' "${MODEL_DIR}/tensorrt_llm/config.pbtxt" replace_config_tags '${triton_max_batch_size}' "128" "${MODEL_DIR}/postprocessing/config.pbtxt" replace_config_tags '${tokenizer_dir}' "${TOKENIZER_DIR}/" "${MODEL_DIR}/postprocessing/config.pbtxt" replace_config_tags '${postprocessing_instance_count}' '1' "${MODEL_DIR}/postprocessing/config.pbtxt" diff --git a/triton_backend/inflight_batcher_llm/scripts/build.sh b/triton_backend/inflight_batcher_llm/scripts/build.sh index 44a5550021..031d623d69 100644 --- a/triton_backend/inflight_batcher_llm/scripts/build.sh +++ b/triton_backend/inflight_batcher_llm/scripts/build.sh @@ -53,7 +53,8 @@ fi # TODO: Remove specifying Triton version after cmake version is upgraded to 3.31.8 # Get TRITON_SHORT_TAG from docker/Dockerfile.multi -LLM_ROOT="$(dirname $0)/../../../.." +LLM_ROOT=$BUILD_DIR/../../.. +LLM_ROOT=$(cd -- "$LLM_ROOT" && pwd) TRITON_SHORT_TAG=$("$LLM_ROOT/jenkins/scripts/get_triton_tag.sh" "$LLM_ROOT") cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install ${BUILD_TESTS_ARG} -DTRITON_COMMON_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_CORE_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_THIRD_PARTY_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_BACKEND_REPO_TAG=${TRITON_SHORT_TAG} .. make install diff --git a/triton_backend/tools/multimodal/client.py b/triton_backend/tools/multimodal/client.py index bac2b4ef5b..b77de50e8d 100755 --- a/triton_backend/tools/multimodal/client.py +++ b/triton_backend/tools/multimodal/client.py @@ -6,6 +6,8 @@ import io import os import sys from datetime import datetime +from pathlib import Path +from typing import List, Tuple sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) @@ -19,8 +21,32 @@ from transformers import AutoProcessor, Blip2Processor from utils import utils +def pixtral_pad_images( + image_list: List[Image.Image]) -> Tuple[np.ndarray, np.ndarray]: + if not image_list: + return np.empty((0, 0, 0, 0), dtype=np.uint8), np.empty((0, 2), + dtype=np.int64) + image_list_np = [np.array(img) for img in image_list] + shapes = [img.shape for img in image_list_np] + assert all(len(s) == 3 + for s in shapes), "All input images must have three dimensions" + assert all(s[-1] == shapes[0][-1] for s in + shapes), "All input images must have the same number of channels" + max_h, max_w = max(s[0] for s in shapes), max(s[1] for s in shapes) + for i in range(len(image_list_np)): + image_list_np[i] = np.pad(image_list_np[i], + ((0, max_h - image_list_np[i].shape[0]), + (0, max_w - image_list_np[i].shape[1]), + (0, 0)), + mode='constant') + raw_image = np.stack(image_list_np, axis=0) + image_sizes = np.array([s[:2] for s in shapes], dtype=np.int64) + return raw_image, image_sizes + + def prepare_inputs(text_data, image_data, + image_sizes, request_output_len_data, beam_width_data, temperature_data, @@ -35,7 +61,6 @@ def prepare_inputs(text_data, image_input_name="image_input"): inputs = [ utils.prepare_tensor("text_input", text_data, grpcclient), - utils.prepare_tensor(image_input_name, image_data, grpcclient), utils.prepare_tensor("max_tokens", request_output_len_data, grpcclient), utils.prepare_tensor("beam_width", beam_width_data, grpcclient), utils.prepare_tensor("temperature", temperature_data, grpcclient), @@ -45,6 +70,14 @@ def prepare_inputs(text_data, utils.prepare_tensor("top_p", top_p_data, grpcclient), utils.prepare_tensor("stream", streaming_data, grpcclient), ] + if image_data is not None: + inputs += [ + utils.prepare_tensor(image_input_name, image_data, grpcclient), + ] + if image_sizes is not None: + inputs += [ + utils.prepare_tensor("image_sizes_input", image_sizes, grpcclient), + ] if repetition_penalty_data is not None: inputs += [ utils.prepare_tensor("repetition_penalty", repetition_penalty_data, @@ -63,20 +96,16 @@ def prepare_inputs(text_data, return inputs -def load_image(image_path): +def load_image(image_path) -> Image.Image: if image_path.startswith("http") or image_path.startswith("https"): - image = Image.open(requests.get(image_path, - stream=True).raw).convert("RGB") + image_bytes = requests.get(image_path, stream=True).content elif image_path.startswith("data:image/jpeg;base64,"): image_base64 = image_path.split(",")[1] - # Decode the base64 string - image_data = base64.b64decode(image_base64) - # Create a BytesIO object from the decoded data - image_buffer = io.BytesIO(image_data) - image = Image.open(image_buffer).convert("RGB") + image_bytes = base64.b64decode(image_base64) else: - image = Image.open(image_path).convert("RGB") - return image + image_bytes = Path(image_path).read_bytes() + + return Image.open(io.BytesIO(image_bytes)).convert("RGB") def load_video(video_path, num_of_frames): @@ -239,7 +268,7 @@ if __name__ == "__main__": required=True, choices=[ 'blip2', 'llava', 'vila', 'mllama', - 'llava_onevision', 'qwen2_vl' + 'llava_onevision', 'qwen2_vl', 'pixtral' ], help="Model type") parser.add_argument("--hf_model_dir", @@ -249,11 +278,18 @@ if __name__ == "__main__": help="path to the model directory") FLAGS = parser.parse_args() # load and process images or video + image_sizes = np.empty((0, 2), dtype=np.int64) if 'vila' in FLAGS.model_type: image_paths = FLAGS.image.split(",") raw_image = [] for image_path in image_paths: raw_image.append(load_image(image_path)) + elif 'pixtral' in FLAGS.model_type: + image_paths = FLAGS.image.split(",") if FLAGS.image else [] + raw_image = [] + for image_path in image_paths: + raw_image.append(load_image(image_path)) + raw_image, image_sizes = pixtral_pad_images(raw_image) elif FLAGS.video is not None: assert FLAGS.video_num_frames is not None, "Number of frames should be provided for video input." raw_video = load_video(FLAGS.video, FLAGS.video_num_frames) @@ -303,6 +339,9 @@ if __name__ == "__main__": FLAGS.text = image_tag + FLAGS.text image_data = np.array([[raw_image]]) image_input_name = "image_bytes_input" + elif 'pixtral' in FLAGS.model_type: + image_data = np.array([raw_image]) + image_input_name = "image_bytes_input" elif 'llava_onevision' in FLAGS.model_type: if FLAGS.video is not None: image_data = np.array([raw_video]) @@ -334,6 +373,9 @@ if __name__ == "__main__": temperature_data = np.array(temperature, dtype=np.float32) streaming = [[FLAGS.streaming]] streaming_data = np.array(streaming, dtype=bool) + image_data = None if image_data.size == 0 else image_data + image_sizes_data = None if image_sizes.size == 0 else np.array( + [image_sizes], dtype=np.int64) model_name = "ensemble" if FLAGS.use_bls: @@ -356,6 +398,7 @@ if __name__ == "__main__": inputs = prepare_inputs(text_data, image_data, + image_sizes_data, request_output_len_data, beam_width_data, temperature_data,