mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com> Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com> Co-authored-by: Yao Yao <lowsfer@users.noreply.github.com> Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
473 lines
14 KiB
Plaintext
473 lines
14 KiB
Plaintext
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
|
*
|
|
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
* property and proprietary rights in and to this material, related
|
|
* documentation and any modifications thereto. Any use, reproduction,
|
|
* disclosure or distribution of this material and related documentation
|
|
* without an express license agreement from NVIDIA CORPORATION or
|
|
* its affiliates is strictly prohibited.
|
|
*/
|
|
|
|
#pragma once
|
|
#include "cuda_hint.cuh"
|
|
#include "defines.h"
|
|
#if !USE_CUSTOM_BARRIER
|
|
#include <cuda/std/barrier>
|
|
using CtaBarrier = cuda::barrier<cuda::thread_scope_block>;
|
|
#else
|
|
|
|
#ifndef __CUDACC__
|
|
#include <cuda_runtime_api.h>
|
|
#endif
|
|
|
|
// PTX qualifier suffixes for mbarrier arrive/test_wait. CUDA 12+ toolchains
// accept the explicit ".release.cta" / ".acquire.cta" ordering qualifiers;
// older toolchains do not, so the unqualified instruction forms are used
// there instead (presumably equivalent default ordering — predates CUDA 12).
#if __CUDACC_VER_MAJOR__ < 12
#define STR_REL_CTA ""
#define STR_ACQ_CTA ""
#else
#define STR_REL_CTA ".release.cta"
#define STR_ACQ_CTA ".acquire.cta"
#endif
|
|
|
|
// Thread scope a barrier operation applies to.
enum class Scope : uint32_t
{
    CTA = 0, // threads of one thread block (CTA)
    CGA = 1, // threads of a whole cluster (CGA)
};

// Memory ordering selected for arrive-style operations.
enum class ArriveOrder : uint32_t
{
    RELEASE = 0, // arrive with release semantics
    RELAXED = 1, // arrive with relaxed (no-ordering) semantics
};

// Opaque 64-bit phase token produced by mbarrier arrive instructions and
// consumed by the wait/test_wait functions.
enum class ArrivalToken : uint64_t
{
};
|
|
|
|
// Hand-rolled replacement for cuda::barrier, used when USE_CUSTOM_BARRIER is
// enabled. It drives the PTX "mbarrier" instructions directly and supports
// two scopes:
//  - Scope::CTA: arrivals return an ArrivalToken which is later passed to
//    the wait functions;
//  - Scope::CGA: cluster-scope arrivals, whose PTX forms yield no token (the
//    destination is the "_" sink), so the corresponding methods return void.
// NOTE(review): mbarrier state must reside in shared memory for these
// instructions to be valid; this class does not enforce that placement —
// callers are responsible for it.
template <Scope defaultScope_ = Scope::CTA>
class MBarrier
{
public:
    using ArrivalToken = ::ArrivalToken;
    static constexpr Scope defaultScope = defaultScope_;
    // Alias matching the cuda::barrier spelling of the token type.
    using arrival_token = ArrivalToken;

    // Initialize the mbarrier to expect `count` arrivals per phase.
    __device__ inline MBarrier(uint32_t count)
    {
        assert(count > 0);
        asm volatile("mbarrier.init.b64 [%0], %1;\n" ::"l"(addr()), "r"(count) : "memory");
    }

    // Invalidate the mbarrier state; the memory must be re-initialized
    // before being used as a barrier again.
    __device__ ~MBarrier()
    {
        asm volatile("mbarrier.inval.b64 [%0];\n" ::"l"(addr()) : "memory");
    }

    // Arrive on the barrier, counting `update` arrivals. CTA scope returns
    // the phase token to wait on; CGA scope returns void. `order` selects
    // .release vs .relaxed (SM90+; pre-SM90 supports only CTA + RELEASE).
    template <Scope scope = defaultScope, ArriveOrder order = ArriveOrder::RELEASE>
    __device__ inline mha::conditional_t<scope == Scope::CTA, ArrivalToken, void> arrive(uint32_t update = 1)
    {
        ArrivalToken token;
#if __CUDA_ARCH__ >= 900
        if constexpr (scope == Scope::CTA)
        {
            // `order` is a compile-time constant, so this switch folds away.
            switch (order)
            {
            case ArriveOrder::RELEASE:
                asm volatile("mbarrier.arrive.release.cta.b64 %0, [%1], %2;\n"
                             : "=l"(token)
                             : "l"(addr()), "r"(update)
                             : "memory");
                break;
            case ArriveOrder::RELAXED:
                asm volatile("mbarrier.arrive.relaxed.cta.b64 %0, [%1], %2;\n"
                             : "=l"(token)
                             : "l"(addr()), "r"(update)
                             : "memory");
                break;
            }
            return token;
        }
        else
        {
            static_assert(scope == Scope::CGA);
            // Cluster-scope arrive produces no token ("_" sink destination).
            switch (order)
            {
            case ArriveOrder::RELEASE:
                asm volatile("mbarrier.arrive.release.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(update)
                             : "memory");
                break;
            case ArriveOrder::RELAXED:
                asm volatile("mbarrier.arrive.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(update)
                             : "memory");
                break;
            }
            return;
        }
#else
        static_assert(scope == Scope::CTA && order == ArriveOrder::RELEASE);
        if (update > 1)
        {
            // Pre-SM90 mbarrier.arrive takes no count operand: emulate an
            // arrive of `update` as a non-completing arrive of (update - 1)
            // followed by one completing arrive. Both arrives belong to the
            // same phase, so their tokens must match.
            asm volatile("mbarrier.arrive.noComplete" STR_REL_CTA ".b64 %0, [%1], %2;\n"
                         : "=l"(token)
                         : "l"(addr()), "r"(update - 1U)
                         : "memory");
            ArrivalToken refToken;
            asm volatile("mbarrier.arrive" STR_REL_CTA ".b64 %0, [%1];\n" : "=l"(refToken) : "l"(addr()) : "memory");
            assert(token == refToken);
            return token;
        }
        else
        {
            asm volatile("mbarrier.arrive" STR_REL_CTA ".b64 %0, [%1];\n" : "=l"(token) : "l"(addr()) : "memory");
            return token;
        }
#endif
    }

    // True when this barrier's storage belongs to the calling CTA: compares
    // the CTA rank owning the address (getctarank) with %cluster_ctarank.
    __device__ inline bool isLocal() const
    {
        uint32_t addrCtaRank;
        asm("getctarank.u64 %0, %1;\n" : "=r"(addrCtaRank) : "l"(addr()));
        uint32_t ctaRank;
        asm("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(ctaRank));
        return addrCtaRank == ctaRank;
    }

    // Arrive on a *remote* CTA's barrier through the cluster shared-memory
    // window. SM90+ only; traps on older architectures. The asserted
    // precondition is that the barrier is NOT local to this CTA.
    __device__ inline void remoteArrive(uint32_t update = 1)
    {
#if __CUDA_ARCH__ >= 900
        assert(!isLocal());
        asm volatile("mbarrier.arrive.release.cluster.shared::cluster.b64 _, [%0], %1;\n"
                     :
                     : "l"(__cvta_generic_to_shared(&mBar)), "r"(update)
                     : "memory");
#else
        asm volatile("trap;\n");
#endif
    }

    // Single arrive that also adds `txCount` to the expected-transaction
    // count (used with TMA / bulk async copies). Always relaxed ordering:
    // the `order` template parameter exists only for signature uniformity
    // with arrive_tx() and is unused here. SM90+ only; traps otherwise.
    template <Scope scope = defaultScope, ArriveOrder order = ArriveOrder::RELEASE>
    __device__ inline mha::conditional_t<scope == Scope::CTA, ArrivalToken, void> arrive_tx_relaxed(uint32_t txCount)
    {
#if __CUDA_ARCH__ >= 900
        if constexpr (scope == Scope::CTA)
        {
            ArrivalToken token;
            asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.b64 %0, [%1], %2;\n"
                         : "=l"(token)
                         : "l"(addr()), "r"(txCount)
                         : "memory");
            return token;
        }
        else
        {
            asm volatile("mbarrier.arrive.expect_tx.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount)
                         : "memory");
            return;
        }
#else
        asm volatile("trap;\n");
#endif
    }

    // Arrive `arriveCount` times and add `txCount` expected transactions.
    // For arriveCount == 1 the fused arrive.expect_tx instruction is used;
    // otherwise the tx count is raised separately (relaxed expect_tx) and
    // the arrivals are delegated to arrive<scope, order>(). SM90+ only.
    template <Scope scope = defaultScope, ArriveOrder order = ArriveOrder::RELEASE>
    __device__ inline mha::conditional_t<scope == Scope::CTA, ArrivalToken, void> arrive_tx(
        uint32_t txCount, uint32_t arriveCount = 1)
    {
#if __CUDA_ARCH__ >= 900
        if (arriveCount == 1)
        {
            if constexpr (scope == Scope::CTA)
            {
                ArrivalToken token;
                switch (order)
                {
                case ArriveOrder::RELEASE:
                    asm volatile("mbarrier.arrive.expect_tx.release.cta.b64 %0, [%1], %2;\n"
                                 : "=l"(token)
                                 : "l"(addr()), "r"(txCount)
                                 : "memory");
                    break;
                case ArriveOrder::RELAXED:
                    asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.b64 %0, [%1], %2;\n"
                                 : "=l"(token)
                                 : "l"(addr()), "r"(txCount)
                                 : "memory");
                    break;
                }
                return token;
            }
            else
            {
                static_assert(scope == Scope::CGA);
                switch (order)
                {
                case ArriveOrder::RELEASE:
                    asm volatile(
                        "mbarrier.arrive.expect_tx.release.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount)
                        : "memory");
                    break;
                case ArriveOrder::RELAXED:
                    asm volatile(
                        "mbarrier.arrive.expect_tx.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount)
                        : "memory");
                    break;
                }
                return;
            }
        }
        else
        {
            if constexpr (scope == Scope::CTA)
            {
                asm volatile("mbarrier.expect_tx.relaxed.cta.b64 [%0], %1;\n" ::"l"(addr()), "r"(txCount) : "memory");
            }
            else
            {
                asm volatile("mbarrier.expect_tx.relaxed.cluster.b64 [%0], %1;\n" ::"l"(addr()), "r"(txCount)
                             : "memory");
            }
            return arrive<scope, order>(arriveCount);
        }
#else
        asm volatile("trap;\n");
#endif
    }

    // Non-blocking check: has the phase identified by `token` completed?
    template <Scope scope = defaultScope>
    __device__ inline bool test_wait(ArrivalToken&& token)
    {
        uint32_t ready;
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.test_wait.acquire.cluster.b64 ready, [%1], %2;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "l"(token)
                : "memory");
        }
        else
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.test_wait" STR_ACQ_CTA
                ".b64 ready, [%1], %2;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "l"(token)
                : "memory");
        }
        return ready != 0;
    }

    // Non-blocking check: has the phase with the given parity completed?
    template <Scope scope = defaultScope>
    __device__ inline bool test_wait_parity(bool parity)
    {
        uint32_t ready;
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.test_wait.parity.acquire.cluster.b64 ready, [%1], %2;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "r"(uint32_t{parity})
                : "memory");
        }
        else
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.test_wait.parity" STR_ACQ_CTA
                ".b64 ready, [%1], %2;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "r"(uint32_t{parity})
                : "memory");
        }
        return ready != 0;
    }

#if __CUDA_ARCH__ >= 900
    // Potentially-blocking wait for the phase identified by `token`, bounded
    // by the kSUSPEND_TIME_HINT suspend-time hint; may return false before
    // the phase completes. SM90+ only.
    template <Scope scope = defaultScope>
    __device__ inline bool try_wait(ArrivalToken&& token)
    {
        uint32_t ready;
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.try_wait.acquire.cluster.b64 ready, [%1], %2, %3;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "l"(token), "n"(kSUSPEND_TIME_HINT)
                : "memory");
        }
        else
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.try_wait.acquire.cta.b64 ready, [%1], %2, %3;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "l"(token), "n"(kSUSPEND_TIME_HINT)
                : "memory");
        }
        return ready != 0;
    }

    // Potentially-blocking wait, parity flavor. SM90+ only.
    template <Scope scope = defaultScope>
    __device__ inline bool try_wait_parity(bool parity)
    {
        uint32_t ready;
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.try_wait.parity.acquire.cluster.b64 ready, [%1], %2, %3;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "r"(uint32_t{parity}), "n"(kSUSPEND_TIME_HINT)
                : "memory");
        }
        else
        {
            asm volatile(
                "{\n"
                ".reg .pred ready;\n"
                "mbarrier.try_wait.parity.acquire.cta.b64 ready, [%1], %2, %3;\n"
                "selp.b32 %0, 1, 0, ready;\n"
                "}\n"
                : "=r"(ready)
                : "l"(addr()), "r"(uint32_t{parity}), "n"(kSUSPEND_TIME_HINT)
                : "memory");
        }
        return ready != 0;
    }
#endif

    // Block until the phase identified by `token` completes. SM90+ uses the
    // hardware-assisted try_wait; older architectures poll test_wait with a
    // nanosleep backoff.
    template <Scope scope = defaultScope>
    __device__ inline void wait(ArrivalToken&& token)
    {
#if __CUDA_ARCH__ >= 900
        poll<true>([&]() { return try_wait<scope>(ArrivalToken{token}); });
#else
        poll<false>([&]() { return test_wait<scope>(ArrivalToken{token}); });
#endif
    }

    // Block until the phase with the given parity completes. A freshly
    // initialized barrier starts from `parity = false`.
    template <Scope scope = defaultScope>
    __device__ inline void wait_parity(bool parity)
    {
#if __CUDA_ARCH__ >= 900
        poll<true>([&]() { return try_wait_parity<scope>(parity); });
#else
        poll<false>([&]() { return test_wait_parity<scope>(parity); });
#endif
    }

    // Arrive and then wait for the resulting phase. CTA scope only, because
    // the CGA-scope arrive returns no token to wait on.
    template <Scope scope = defaultScope, ArriveOrder order = ArriveOrder::RELEASE>
    __device__ inline mha::enable_if_t<scope == Scope::CTA, void> arrive_and_wait(uint32_t update = 1)
    {
        wait<scope>(arrive<scope, order>(update));
    }

private:
    // Generic 64-bit address of the barrier state, as consumed by the "l"
    // constraints in the asm statements above.
    __device__ inline uint64_t addr() const
    {
        return reinterpret_cast<uint64_t>(&mBar);
    }

    // Invoke `func` until it returns true. When `func` can block on its own
    // (try_wait, SM90+) just spin; otherwise back off with __nanosleep whose
    // duration grows 1.25x per iteration. The uint32_t cast truncates the
    // initial sub-1ns values to 0, so the first iterations busy-spin.
    template <bool funcSupportsBlocking, typename F>
    __device__ inline static void poll(F&& func)
    {
        if constexpr (funcSupportsBlocking)
        {
            while (!func())
            {
            }
        }
        else
        {
            float sleepDuration = 0.125F;
            while (!func())
            {
                __nanosleep(uint32_t(sleepDuration));
                sleepDuration = sleepDuration * 1.25F + 0.F;
            }
        }
    }

public:
    // Suspend-time hint (nanoseconds) passed to mbarrier.try_wait above.
    static constexpr uint32_t kSUSPEND_TIME_HINT = 0xFFFFFFFFU;

private:
    uint64_t mBar; // opaque 64-bit mbarrier state word
};
|
|
|
|
// Construct an MBarrier in place at `bar`, expecting `count` arrivals per
// phase — analogous to the libcu++ init() helper for cuda::barrier.
template <Scope defaultScope>
__device__ inline void init(MBarrier<defaultScope>* bar, uint32_t count)
{
    ::new (static_cast<void*>(bar)) MBarrier<defaultScope>(count);
}
|
|
|
|
// Barrier types for thread-block (CTA) and cluster (CGA) scope.
using CtaBarrier = MBarrier<Scope::CTA>;
using CgaBarrier = MBarrier<Scope::CGA>;
|
|
|
|
// Maps a monotonically increasing barrier-use index `i` to the mbarrier
// phase parity bit when `nbBars` barriers are cycled round-robin: the parity
// flips each time all nbBars barriers have been used once.
template <uint32_t nbBars>
__device__ inline bool toParity(uint32_t i)
{
    // Within each double round of 2*nbBars uses, the second half is the odd
    // phase. (Equivalent to i % (nbBars * 2) / nbBars as a bool.)
    return i % (nbBars * 2) >= nbBars;
}
|
|
|
|
// Thin wrapper over the PTX named-barrier instructions (barrier.cta.*).
// The asserts reflect the hardware contract: barrier index below 16, and an
// arrive count that is a multiple of the warp size (32).
class NamedBarrier
{
public:
    // idxBar: hardware named-barrier index in [0, 16).
    // arriveCount: total number of threads expected; multiple of 32.
    __device__ inline NamedBarrier(uint32_t idxBar, uint32_t arriveCount)
        : mName{idxBar}
        , mArriveCount{arriveCount}
    {
        assert(idxBar < 16 && arriveCount % 32 == 0);
    }

    // Signal arrival on the named barrier without waiting.
    __device__ inline void arrive() const
    {
        asm volatile("barrier.cta.arrive %0, %1;\n" ::"r"(mName), "r"(mArriveCount) : "memory");
    }

    // Signal arrival and block until mArriveCount threads have arrived.
    __device__ inline void arrive_and_wait() const
    {
        asm volatile("barrier.cta.sync %0, %1;\n" ::"r"(mName), "r"(mArriveCount) : "memory");
    }

private:
    uint32_t const mName;        // hardware named-barrier index
    uint32_t const mArriveCount; // expected number of arriving threads
};
|
|
|
|
// Convenience wrapper: full arrive-and-wait on hardware named barrier
// `idxBar`, expecting `arriveCount` threads (must be a multiple of 32).
__device__ inline void namedBarSync(uint32_t idxBar, uint32_t arriveCount)
{
    NamedBarrier{idxBar, arriveCount}.arrive_and_wait();
}
|
|
#endif
|