TensorRT-LLMs/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/IntFastDiv.h
Matthias Jouanneaux 514baf1287
[fix] Fix comment to pass guardwords check (#5191)
Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
2025-06-13 15:49:59 +08:00

175 lines
4.7 KiB
C++

/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cute/config.hpp>
#include <stdexcept>
namespace trtllm::dev
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// *************************************************************************************************
// IntFastDiv class.
// *************************************************************************************************
////////////////////////////////////////////////////////////////////////////////////////////////////
class IntFastDiv
{
public:
// we allow a default constructor (and over-writing the divisor with assignment)
CUTE_HOST
IntFastDiv()
: mDivisor(1)
, mMagicM(0)
, mMagicS(-1)
, mAddSign(1)
{
}
CUTE_HOST
IntFastDiv(int divisor)
: mDivisor(divisor)
{
if (mDivisor == 0)
throw std::runtime_error("IntFastDiv: cannot divide by 0");
updateMagicNumbers();
}
CUTE_HOST
IntFastDiv& operator=(int divisor)
{
this->mDivisor = divisor;
if (this->mDivisor == 0)
throw std::runtime_error("IntFastDiv: cannot divide by 0");
updateMagicNumbers();
return *this;
}
CUTE_HOST_DEVICE
operator int() const
{
return mDivisor;
}
private:
int mDivisor;
int mMagicM;
int mMagicS;
int mAddSign;
// Hacker'mMagicS Delight, Second Edition, Chapter 10, Integer Division By Constants
CUTE_HOST
void updateMagicNumbers()
{
if (mDivisor == 1)
{
mMagicM = 0;
mMagicS = -1;
mAddSign = 1;
return;
}
else if (mDivisor == -1)
{
mMagicM = 0;
mMagicS = -1;
mAddSign = -1;
return;
}
int p;
unsigned int tmpAd, tmpAnc, delta, q1, r1, q2, r2, t;
unsigned const two31 = 0x80000000;
tmpAd = abs(mDivisor);
t = two31 + ((unsigned int) mDivisor >> 31);
tmpAnc = t - 1 - t % tmpAd;
p = 31;
q1 = two31 / tmpAnc;
r1 = two31 - q1 * tmpAnc;
q2 = two31 / tmpAd;
r2 = two31 - q2 * tmpAd;
do
{
++p;
q1 = 2 * q1;
r1 = 2 * r1;
if (r1 >= tmpAnc)
{
++q1;
r1 -= tmpAnc;
}
q2 = 2 * q2;
r2 = 2 * r2;
if (r2 >= tmpAd)
{
++q2;
r2 -= tmpAd;
}
delta = tmpAd - r2;
} while (q1 < delta || (q1 == delta && r1 == 0));
this->mMagicM = q2 + 1;
if (mDivisor < 0)
this->mMagicM = -this->mMagicM;
this->mMagicS = p - 32;
if ((mDivisor > 0) && (mMagicM < 0))
mAddSign = 1;
else if ((mDivisor < 0) && (mMagicM > 0))
mAddSign = -1;
else
mAddSign = 0;
}
CUTE_HOST_DEVICE
friend int operator/(int const dividend, IntFastDiv const& divisor);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
int operator/(int const dividend, IntFastDiv const& divisor)
{
int q;
#ifdef __CUDA_ARCH__
asm("mul.hi.s32 %0, %1, %2;" : "=r"(q) : "r"(divisor.mMagicM), "r"(dividend));
#else
q = (((unsigned long long) ((long long) divisor.mMagicM * (long long) dividend)) >> 32);
#endif
q += dividend * divisor.mAddSign;
if (divisor.mMagicS >= 0)
{
q >>= divisor.mMagicS; // we rely on this to be implemented as arithmetic shift
q += (((unsigned int) q) >> 31);
}
return q;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
CUTE_HOST_DEVICE
int operator%(int const dividend, IntFastDiv const& divisor)
{
int quotient = dividend / divisor;
int remainder = dividend - quotient * divisor;
return remainder;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace trtllm::dev