/* * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "cutlass_extensions/gemm_configs.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h" #include "tensorrt_llm/kernels/quantization.h" #include "tensorrt_llm/thop/thUtils.h" #include #include #include #include #include #include #include #include namespace tkc = tensorrt_llm::cutlass_extensions; using tensorrt_llm::kernels::internal_cutlass_kernels::CutlassFp4GemmRunner; using tensorrt_llm::kernels::internal_cutlass_kernels::CutlassFp4GemmRunnerInterface; namespace torch_ext { namespace { tkc::CutlassGemmConfig getDefaultGemmConfig(int64_t m, int64_t n, int64_t k) { return tkc::CutlassGemmConfig(tkc::CutlassTileConfigSM100::CtaShape128x256x128B, tkc::MainloopScheduleType::AUTO, tkc::EpilogueScheduleType::AUTO, tkc::ClusterShape::ClusterShape_1x1x1); } template void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, int64_t m, int64_t n, int64_t k, int64_t batch_count, tkc::CutlassGemmConfig const& gemmConfig) { CutlassFp4GemmRunner gemmRunner; int64_t const wsBytes = gemmRunner.getWorkspaceSize(m, n, k, batch_count); at::Tensor workspace = at::detail::empty_cuda({wsBytes}, at::ScalarType::Char, mat1.device(), std::nullopt); gemmRunner.gemm(out.data_ptr(), mat1.const_data_ptr(), mat2.const_data_ptr(), mat1Scale.const_data_ptr(), mat2Scale.const_data_ptr(), globalScale.data_ptr(), m, n, k, batch_count, gemmConfig, reinterpret_cast(workspace.data_ptr()), wsBytes, at::cuda::getCurrentCUDAStream(mat1.get_device())); } // mat1: [B, M, K / 2], FLOAT4_E2M1X2 // mat2: [B, N, K / 2], FLOAT4_E2M1X2 // out: [B, M, N], fp16/bf16/fp32 // mat1Scale: B * ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) // mat2Scale: B * ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) // globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max())) // Only NVFP4 is currently supported at::Tensor fp4_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, std::optional out_dtype, tkc::CutlassGemmConfig const* maybe_config = nullptr) { CHECK_INPUT(mat1, FLOAT4_E2M1X2); CHECK_INPUT(mat2, FLOAT4_E2M1X2); CHECK_INPUT(mat1Scale, SF_DTYPE); CHECK_INPUT(mat2Scale, SF_DTYPE); CHECK_INPUT(globalScale, at::ScalarType::Float); TORCH_CHECK(!sfUseUE8M0, "use UE8M0 for FP4 Block Scale Factors is not supported yet"); TORCH_CHECK(mat1.dim() == 3, "mat1 must be a batch of matrices"); TORCH_CHECK(mat2.dim() == 3, "mat2 must be a batch of matrices"); TORCH_CHECK(mat1.sizes()[0] == mat2.sizes()[0], "mat1 and mat2 must have the same number of batches"); TORCH_CHECK(mat1.sizes()[2] == mat2.sizes()[2], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[1], "x", mat1.sizes()[2], " and ", mat2.sizes()[1], "x", mat2.sizes()[2], ")"); auto const m = mat1.sizes()[1]; auto const n = mat2.sizes()[1]; auto const k = mat1.sizes()[2] * 2; auto const batch_count = mat1.sizes()[0]; auto config = maybe_config ? *maybe_config : getDefaultGemmConfig(m, n, k); constexpr int alignment = 32; TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, ", but got mat1 shape: (", mat1.sizes()[0], "x", mat1.sizes()[1], "), k: ", k, "."); TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, ", but got mat2 shape: (", mat2.sizes()[0], "x", mat2.sizes()[1], ")."); if (!out_dtype) { out_dtype = torch::kHalf; } TORCH_CHECK(out_dtype == torch::kFloat || out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be one of fp16/bf16/fp32. It defaults to fp16."); at::Tensor out = at::detail::empty_cuda({batch_count, m, n}, out_dtype.value(), mat1.device(), std::nullopt); switch (out_dtype.value()) { case at::ScalarType::Half: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; case at::ScalarType::BFloat16: runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; case at::ScalarType::Float: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; default: C10_THROW_ERROR(NotImplementedError, "out_dtype must be one of fp16/bf16/fp32."); } return out; } // mat1: [M, K / 2], FLOAT4_E2M1X2 // mat2: [N, K / 2], FLOAT4_E2M1X2 // out: [M, N], fp16/bf16/fp32 // mat1Scale: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) // mat2Scale: ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) // globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max())) // Only NVFP4 is currently supported at::Tensor fp4_gemm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, std::optional out_dtype, tkc::CutlassGemmConfig const* maybe_config = nullptr) { CHECK_INPUT(mat1, FLOAT4_E2M1X2); CHECK_INPUT(mat2, FLOAT4_E2M1X2); CHECK_INPUT(mat1Scale, SF_DTYPE); CHECK_INPUT(mat2Scale, SF_DTYPE); CHECK_INPUT(globalScale, at::ScalarType::Float); TORCH_CHECK(!sfUseUE8M0, "use UE8M0 for FP4 Block Scale Factors is not supported yet"); TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); TORCH_CHECK(mat1.sizes()[1] == mat2.sizes()[1], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); auto const m = mat1.sizes()[0]; auto const n = mat2.sizes()[0]; auto const k = mat1.sizes()[1] * 2; constexpr auto batch_count = 1; // Batch count is 1 for plain GEMM auto config = maybe_config ? *maybe_config : getDefaultGemmConfig(m, n, k); constexpr int alignment = 32; TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, ", but got mat1 shape: (", mat1.sizes()[0], "x", mat1.sizes()[1], "), k: ", k, "."); TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, ", but got mat2 shape: (", mat2.sizes()[0], "x", mat2.sizes()[1], ")."); if (!out_dtype) { out_dtype = torch::kHalf; } TORCH_CHECK(out_dtype == torch::kFloat || out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be one of fp16/bf16/fp32. It defaults to fp16."); at::Tensor out = at::detail::empty_cuda({m, n}, out_dtype.value(), mat1.device(), std::nullopt); switch (out_dtype.value()) { case at::ScalarType::Half: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; case at::ScalarType::BFloat16: runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; case at::ScalarType::Float: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, batch_count, config); break; default: C10_THROW_ERROR(NotImplementedError, "out_dtype must be one of fp16/bf16/fp32."); } return out; } struct Fp4GemmId { int64_t n; int64_t k; int64_t b; c10::ScalarType out_dtype; Fp4GemmId(int64_t n_, int64_t k_, int64_t b_, c10::ScalarType out_dtype_) : n(n_) , k(k_) , b(b_) , out_dtype(out_dtype_) { } Fp4GemmId() : n(-1) , k(-1) , b(-1) , out_dtype(c10::ScalarType::Half) { } bool operator==(Fp4GemmId const& other) const { return n == other.n && k == other.k && b == other.b && out_dtype == other.out_dtype; } }; struct Fp4GemmIdHash { std::size_t operator()(Fp4GemmId const& id) const { auto h1 = std::hash{}(id.n); auto h2 = std::hash{}(id.k); auto h3 = std::hash{}(id.b); auto h4 = std::hash{}(static_cast(id.out_dtype)); return h1 ^ h2 ^ h3 ^ h4; } }; void eventDelete(cudaEvent_t event) { // Do not use check_cuda_error here; we need to swallow all // exceptions in destructors. if (event != nullptr) { cudaEventDestroy(event); } } float profileConfigForProblem(CutlassFp4GemmRunnerInterface& gemmRunner, int64_t m, int64_t n, int64_t k, int64_t batch_count, at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, at::Tensor& workspace, int64_t wsBytes, tkc::CutlassGemmConfig const& config) { constexpr int warmup = 5; constexpr int runs = 10; auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device()); auto runGemm = [&] { gemmRunner.gemm(out.data_ptr(), mat1.const_data_ptr(), mat2.const_data_ptr(), mat1Scale.const_data_ptr(), mat2Scale.const_data_ptr(), globalScale.data_ptr(), m, n, k, batch_count, config, reinterpret_cast(workspace.data_ptr()), wsBytes, stream); }; for (int i = 0; i < warmup; ++i) { runGemm(); } // Wrap start/stop in unique ptrs so we don't leak memory if there are errors. cudaEvent_t start; tensorrt_llm::common::check_cuda_error(cudaEventCreate(&start)); std::unique_ptr::type, decltype(&eventDelete)> start_ptr{start, &eventDelete}; cudaEvent_t stop; tensorrt_llm::common::check_cuda_error(cudaEventCreate(&stop)); std::unique_ptr::type, decltype(&eventDelete)> stop_ptr{stop, &eventDelete}; tensorrt_llm::common::check_cuda_error(cudaStreamSynchronize(stream)); tensorrt_llm::common::check_cuda_error(cudaEventRecord(start, stream)); for (int i = 0; i < runs; ++i) { runGemm(); } tensorrt_llm::common::check_cuda_error(cudaEventRecord(stop, stream)); tensorrt_llm::common::check_cuda_error(cudaEventSynchronize(stop)); float elapsed; tensorrt_llm::common::check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); return elapsed / runs; } std::pair runProfilingFor( int64_t m, Fp4GemmId const& gemmId, CutlassFp4GemmRunnerInterface& gemmRunner) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); float bestTime = std::numeric_limits::max(); tkc::CutlassGemmConfig bestConfig; int64_t bestIdx{}; auto configs = gemmRunner.getConfigs(); // Check alignment here to avoid confusing errors from failing inside the kernel. constexpr int alignment = 32; TORCH_CHECK(gemmId.k % alignment == 0, "Expected k to be divisible by ", alignment, ", but got ", gemmId.k); TORCH_CHECK(gemmId.n % alignment == 0, "Expected n to be divisible by ", alignment, ", but got ", gemmId.n); constexpr auto randMin = std::numeric_limits::min(); constexpr auto randMax = std::numeric_limits::max(); auto const batch_count = gemmId.b; // Note that we have to use randint() as the fp4/sf dtypes are really just aliases for integer types. at::Tensor mat1 = at::randint(randMin, randMax, {batch_count, m, gemmId.k / 2}, at::ScalarType::Byte, std::nullopt, torch::kCUDA, std::nullopt) .view(FLOAT4_E2M1X2); at::Tensor mat2 = at::randint(randMin, randMax, {batch_count, gemmId.n, gemmId.k / 2}, at::ScalarType::Byte, std::nullopt, torch::kCUDA, std::nullopt) .view(FLOAT4_E2M1X2); at::Tensor out = at::randn({batch_count, m, gemmId.n}, gemmId.out_dtype, std::nullopt, torch::kCUDA, std::nullopt); constexpr int sfVecSize = 16; auto const scaleColSize = gemmId.k / sfVecSize; auto mat1ScaleNumElems = tensorrt_llm::computeSFSize(m, scaleColSize) * batch_count; at::Tensor mat1Scale = at::randint( randMin, randMax, {mat1ScaleNumElems}, at::ScalarType::Byte, std::nullopt, torch::kCUDA, std::nullopt) .view(SF_DTYPE); auto mat2ScaleNumElems = tensorrt_llm::computeSFSize(gemmId.n, scaleColSize) * batch_count; at::Tensor mat2Scale = at::randint( randMin, randMax, {mat2ScaleNumElems}, at::ScalarType::Byte, std::nullopt, torch::kCUDA, std::nullopt) .view(SF_DTYPE); at::Tensor globalScale = at::randn({1}, at::ScalarType::Float, std::nullopt, torch::kCUDA, std::nullopt); int64_t const wsBytes = gemmRunner.getWorkspaceSize(m, gemmId.n, gemmId.k, batch_count); at::Tensor workspace = at::detail::empty_cuda({wsBytes}, at::ScalarType::Char, torch::kCUDA, std::nullopt); for (int64_t i = 0; i < static_cast(configs.size()); ++i) { auto& config = configs[i]; try { float const time = profileConfigForProblem(gemmRunner, m, gemmId.n, gemmId.k, batch_count, out, mat1, mat2, mat1Scale, mat2Scale, globalScale, workspace, wsBytes, config); if (time < bestTime) { bestTime = time; bestConfig = config; bestIdx = i; } } catch (std::exception const& e) { std::ostringstream msg; msg << "Cannot profile configuration " << config.toString() << "\n (for" << " m=" << m << ", n=" << gemmId.n << ", k=" << gemmId.k << ")" << ", reason: \"" << e.what() << "\". Skipped"; TLLM_LOG_TRACE(msg.str()); cudaGetLastError(); // Reset the last cudaError to cudaSuccess. } } if (std::isinf(bestTime)) { std::ostringstream msg; msg << "Have not found any valid GEMM config for shape (" << "m=" << m << ", n=" << gemmId.n << ", k=" << gemmId.k << "). Will try to use default or fail at" << "runtime"; TLLM_LOG_WARNING(msg.str()); return {getDefaultGemmConfig(m, gemmId.n, gemmId.k), -1}; } return {bestConfig, bestIdx}; } class Fp4GemmProfiler { public: // Maps values of M to the best config and the config's index. using MProfileMap = std::unordered_map>; void profileTactics(int64_t m, Fp4GemmId const& gemmId, CutlassFp4GemmRunnerInterface& gemmRunner) { if (getBestConfigImpl(m, gemmId) != std::nullopt) { return; } auto bestConfigAndIdx = runProfilingFor(m, gemmId, gemmRunner); setBestConfig(m, gemmId, bestConfigAndIdx); } std::optional getBestConfig(int64_t m, Fp4GemmId const& gemmId) { auto result = getBestConfigImpl(m, gemmId); if (result != std::nullopt) { return result->first; } return std::nullopt; } std::optional getBestConfigIdx(int64_t m, Fp4GemmId const& gemmId) { auto result = getBestConfigImpl(m, gemmId); if (result != std::nullopt) { return result->second; } return std::nullopt; } private: std::optional> getBestConfigImpl(int64_t m, Fp4GemmId const& gemmId) { auto profileMapIt = mProfileMap.find(gemmId); if (profileMapIt == mProfileMap.end()) { return std::nullopt; } auto& mValToConfig = profileMapIt->second; auto configIt = mValToConfig.find(m); if (configIt != mValToConfig.end()) { return configIt->second; } return std::nullopt; } void setBestConfig( int64_t m, Fp4GemmId const& gemmId, std::pair const& configAndIdx) { // Note that profileMapIt will point to the existing map associated // with gemmId if such a map exists. auto [profileMapIt, _] = mProfileMap.try_emplace(gemmId, MProfileMap{}); auto& profileMap = profileMapIt->second; profileMap.emplace(m, configAndIdx); } std::unordered_map mProfileMap; }; } // namespace at::Tensor fp4_bmm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, std::optional out_dtype) { // The functional version of this op does not do any profiling; use the profiler class below instead for // better performance. // Note that we can still add a heuristic here. return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, sfUseUE8M0, out_dtype); } at::Tensor fp4_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, std::optional out_dtype) { // The functional version of this op does not do any profiling; use the profiler class below instead for // better performance. // Note that we can still add a heuristic here. return fp4_gemm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, sfUseUE8M0, out_dtype); } class FP4GemmRunnerBase : public torch::CustomClassHolder { public: explicit FP4GemmRunnerBase(at::ScalarType outputDtype) : mOutputDtype(outputDtype) { if (outputDtype == at::ScalarType::Half) { mGemmRunner = std::make_unique>(); } else if (outputDtype == at::ScalarType::Float) { mGemmRunner = std::make_unique>(); } #ifdef ENABLE_BF16 else if (outputDtype == at::ScalarType::BFloat16) { mGemmRunner = std::make_unique>(); } #endif else { C10_THROW_ERROR(NotImplementedError, "out_dtype must be one of fp16/bf16/fp32."); } mConfigs = mGemmRunner->getConfigs(); } protected: int64_t getConfigIdForBucket(int64_t bucket, int64_t n, int64_t k, int64_t b) { auto gemmId = Fp4GemmId(n, k, b, mOutputDtype); auto configIdx = mGemmProfiler.getBestConfigIdx(bucket, gemmId); TORCH_CHECK(configIdx != std::nullopt, "Need to run profiling before getting best config"); return *configIdx; } int64_t getBestConfigId(int64_t m, int64_t n, int64_t k, int64_t b) { std::lock_guard lk(mProfilerMutex); auto bucketIt = std::upper_bound(mBuckets.begin(), mBuckets.end(), m); if (bucketIt == mBuckets.begin()) { return getConfigIdForBucket(mBuckets[0], n, k, b); } return getConfigIdForBucket(*(bucketIt - 1), n, k, b); } void runProfile(int64_t n, int64_t k, int64_t b, std::vector& buckets) { TORCH_CHECK(!buckets.empty(), "At least one bucket must be specified"); std::lock_guard lk(mProfilerMutex); mBuckets = std::move(buckets); std::sort(mBuckets.begin(), mBuckets.end()); for (auto m : mBuckets) { TORCH_CHECK(m > 0, "Bucket sizes must be positive."); auto gemmId = Fp4GemmId(n, k, b, mOutputDtype); TORCH_CHECK(mGemmRunner != nullptr); mGemmProfiler.profileTactics(m, gemmId, *mGemmRunner); } } at::ScalarType getOutputDtype() const { return mOutputDtype; } tkc::CutlassGemmConfig const& getConfig(int64_t idx) const { return mConfigs.at(idx); } std::vector::size_type getNumConfigs() const { return mConfigs.size(); } private: std::mutex mProfilerMutex; Fp4GemmProfiler mGemmProfiler{}; std::shared_ptr mGemmRunner{nullptr}; std::vector mConfigs; at::ScalarType mOutputDtype; std::vector mBuckets; }; class FP4BmmRunner : public FP4GemmRunnerBase { using FP4GemmRunnerBase::FP4GemmRunnerBase; public: static c10::intrusive_ptr getInstance(c10::ScalarType outputDtype) { static std::mutex instance_mutex; static std::unordered_map> dtype_to_instance; std::lock_guard lock(instance_mutex); auto instance_it = dtype_to_instance.find(outputDtype); if (instance_it != dtype_to_instance.end()) { return instance_it->second; } auto result = c10::make_intrusive(outputDtype); dtype_to_instance.emplace(outputDtype, result); return result; } void profile(int64_t n, int64_t k, int64_t b, std::vector buckets) { runProfile(n, k, b, buckets); } int64_t getBestBmmConfigId(int64_t m, int64_t n, int64_t k, int64_t b) { return getBestConfigId(m, n, k, b); } at::Tensor runBmm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, int64_t configIdx) const { TORCH_CHECK(configIdx >= 0 && configIdx < static_cast(getNumConfigs())); auto const& config = getConfig(configIdx); return fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, sfUseUE8M0, getOutputDtype(), &config); } }; class FP4GemmRunner : public FP4GemmRunnerBase { using FP4GemmRunnerBase::FP4GemmRunnerBase; public: static c10::intrusive_ptr getInstance(c10::ScalarType outputDtype) { static std::mutex instance_mutex; static std::unordered_map> dtype_to_instance; std::lock_guard lock(instance_mutex); auto instance_it = dtype_to_instance.find(outputDtype); if (instance_it != dtype_to_instance.end()) { return instance_it->second; } auto result = c10::make_intrusive(outputDtype); dtype_to_instance.emplace(outputDtype, result); return result; } void profile(int64_t n, int64_t k, std::vector buckets) { runProfile(n, k, /* batch count */ 1, buckets); } int64_t getBestGemmConfigId(int64_t m, int64_t n, int64_t k) { return getBestConfigId(m, n, k, /* batch count */ 1); } at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale, at::Tensor const& mat2Scale, at::Tensor const& globalScale, bool sfUseUE8M0, int64_t configIdx) const { TORCH_CHECK(configIdx >= 0 && configIdx < static_cast(getNumConfigs())); auto const& config = getConfig(configIdx); return fp4_gemm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, sfUseUE8M0, getOutputDtype(), &config); } }; } // namespace torch_ext TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.class_("FP4BmmRunner") .def_static("get_instance", &torch_ext::FP4BmmRunner::getInstance) .def("run_bmm", &torch_ext::FP4BmmRunner::runBmm) .def("run_profile", &torch_ext::FP4BmmRunner::profile) .def("get_best_config_id", &torch_ext::FP4BmmRunner::getBestBmmConfigId); m.def( "fp4_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, Tensor globalScale, " "bool sfUseUE8M0, ScalarType? out_dtype=None) -> Tensor"); } TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.class_("FP4GemmRunner") .def_static("get_instance", &torch_ext::FP4GemmRunner::getInstance) .def("run_gemm", &torch_ext::FP4GemmRunner::runGemm) .def("run_profile", &torch_ext::FP4GemmRunner::profile) .def("get_best_config_id", &torch_ext::FP4GemmRunner::getBestGemmConfigId); m.def( "fp4_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, Tensor globalScale, bool sfUseUE8M0, " "ScalarType? out_dtype=None) -> Tensor"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("fp4_bmm", &torch_ext::fp4_bmm); m.impl("fp4_gemm", &torch_ext::fp4_gemm); }