Merge branch 'main' into update_mnnvl_test

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Bo Li 2026-01-08 22:07:55 +08:00 committed by GitHub
commit 11ed735a4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4735 changed files with 79614 additions and 33156 deletions

.github/CODEOWNERS
View File

@ -1,5 +1,18 @@
# This file defines code ownership rules for the repository.
## TensorRT-LLM QA
### Integration Tests
/tests/integration/test_lists/qa @NVIDIA/trt-llm-qa
/tests/integration/defs/examples/test_ray.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/examples/test_redrafter.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/accuracy @NVIDIA/trt-llm-qa-function
/tests/integration/defs/stress_test @NVIDIA/trt-llm-qa-function
/tests/integration/defs/triton_server @NVIDIA/trt-llm-qa-function
/tests/integration/defs/test_e2e.py @NVIDIA/trt-llm-qa-function
/tests/integration/defs/disaggregated @NVIDIA/trt-llm-qa-serving
/tests/integration/defs/sysinfo @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf @NVIDIA/trt-llm-qa-perf
/tests/integration/defs/perf/disagg @NVIDIA/trt-llm-qa-serving
## TensorRT-LLM Infra
### CI
@ -13,6 +26,11 @@
## TensorRT-LLM - Docs
/docs @NVIDIA/trt-llm-doc-owners
/CODING_GUIDELINES.md @NVIDIA/trt-llm-doc-owners
/CODE_OF_CONDUCT.md @NVIDIA/trt-llm-doc-owners
/CONTAINER_SOURCE.md @NVIDIA/trt-llm-doc-owners
/CONTRIBUTING.md @NVIDIA/trt-llm-doc-owners
/README.md @NVIDIA/trt-llm-doc-owners
## Examples
/examples @NVIDIA/trt-llm-doc-owners
@ -183,6 +201,7 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
## and license compliance when adding, removing, or changing versions of dependencies.
### License Files
/LICENSE @NVIDIA/trt-llm-oss-compliance
/ATTRIBUTIONS-*.md @NVIDIA/trt-llm-oss-compliance
/jenkins/license_cpp.json @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance
### Python Dependency Management

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v6
- name: Get assignee
uses: actions/github-script@v6
uses: actions/github-script@v8
id: get-assignee
with:
github-token: ${{secrets.GITHUB_TOKEN}}

View File

@ -14,7 +14,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v9
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Issue has not received an update in over 14 days. Adding stale label.'

View File

@ -53,6 +53,7 @@ jobs:
"amukkara",
"anish-shanbhag",
"arekay",
"arysef",
"atrifex",
"Autumn1998",
"baize97",
@ -121,6 +122,7 @@ jobs:
"heyuhhh",
"hijkzzz",
"hlu1",
"hnover-nv",
"HuiGao-NV",
"hvagadia",
"hypdeb",
@ -215,6 +217,7 @@ jobs:
"omera-nv",
"pamelap-nvidia",
"pcastonguay",
"pcicotti",
"pdrake-nv",
"peaceh-nv",
"pengbowang-nv",
@ -243,6 +246,7 @@ jobs:
"schetlur-nv",
"shaharmor98",
"shangz-ai",
"sherry-1001",
"shifangx",
"Shixiaowei02",
"Shunkangz",
@ -262,6 +266,7 @@ jobs:
"syuoni",
"Tabrizian",
"talorabr",
"taylor-yb-lee",
"tburt-nv",
"tcherckez-nvidia",
"thorjohnsen",

View File

@ -36,7 +36,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Add bot help comment
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
const helpMessage = "" +

View File

@ -34,7 +34,7 @@ jobs:
if: github.event_name == 'workflow_dispatch'
steps:
- name: Update commit status
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
state = 'pending'
@ -60,7 +60,7 @@ jobs:
with:
paths: results/**/results*.xml
- name: Update commit status
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
github.rest.repos.createCommitStatus({

View File

@ -17,10 +17,10 @@ jobs:
if: github.repository == 'NVIDIA/TensorRT-LLM'
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v6
with:
python-version: '3.x'

View File

@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout private action repository
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
repository: NVIDIA/goggles_action
path: ./.github/actions/goggles_action # local path to store the action

View File

@ -59,10 +59,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

View File

@ -29,11 +29,11 @@ jobs:
name: Pre-commit Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
ref: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.ref || github.ref }}
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: '3.12'
cache: 'pip'

.gitignore
View File

@ -40,6 +40,8 @@ tensorrt_llm/libs
tensorrt_llm/bindings.*.so
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/**/*.pyi
tensorrt_llm/tensorrt_llm_transfer_agent_binding.*.so
tensorrt_llm/tensorrt_llm_transfer_agent_binding.pyi
tensorrt_llm/deep_ep/
tensorrt_llm/deep_ep_cpp_tllm.*.so
tensorrt_llm/deep_ep_cpp_tllm.pyi
@ -56,13 +58,14 @@ tensorrt_llm/scripts
docs/source/**/*.rst
!docs/source/examples/index.rst
!docs/source/deployment-guide/config_table.rst
!docs/source/deployment-guide/note_sections.rst
!docs/source/_includes/note_sections.rst
*.swp
# Testing
.coverage.*
results_trt/
llm-test-workspace/
ad-test-workspace/
# build/debug
*.safetensors
@ -76,6 +79,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/
# User config files

View File

@ -38,8 +38,8 @@ FetchContent_Declare(
FetchContent_Declare(
deepgemm
GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM
GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch
GIT_SUBMODULES_RECURSE
ON
SOURCE_SUBDIR

View File

@ -487,9 +487,17 @@ else:
f.read()
```
## Documentation Guidelines
#### CLI Options in Documentation
1. When documenting CLI commands for `trtllm-serve`, `trtllm-bench`, `trtllm-eval`, or similar tools, prefer using `--config` over `--extra_llm_api_options` for specifying configuration files.
- `--config` is the preferred, shorter alias for configuration file options.
- Example: `trtllm-serve --model <model_path> --config config.yaml` (preferred)
- Avoid: `trtllm-serve --model <model_path> --extra_llm_api_options config.yaml`
## NVIDIA Copyright
1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
1. All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification. The following block of text should be prepended to the top of all files. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
```cpp
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

View File

@ -10,7 +10,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
[![torch](https://img.shields.io/badge/torch-2.9.0-green)](https://pytorch.org)
[![version](https://img.shields.io/badge/release-1.2.0rc6-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-1.2.0rc8-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

View File

@ -68,6 +68,7 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
ON)
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
"Using open sourced Cutlass AR gemm kernel" ON)
option(SKIP_SOFTMAX_STAT "Enable Statistics of Skip-Softmax" OFF)
message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")
@ -360,6 +361,11 @@ else()
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
endif()
if(SKIP_SOFTMAX_STAT)
add_compile_definitions("SKIP_SOFTMAX_STAT")
message(STATUS "SKIP_SOFTMAX_STAT is enabled")
endif()
# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
# be found in
# https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html#index-mcmodel_003dmedium-1

View File

@ -380,6 +380,7 @@ public:
, mBeamWidth(beamWidth)
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
, mNumFrontBlocksRemoved(0)
, mCurrentPrepopulatedPromptLen(std::numeric_limits<SizeType32>::max())
{
auto const numWindowSizes = windowSizeToMetadata.size();
mCacheBlockIds.reserve(numWindowSizes);
@ -500,6 +501,20 @@ public:
return mKvCacheRetentionConfig.getDirectory();
}
[[nodiscard]] SizeType32 getCurrentPrepopulatedPromptLen() const
{
return mCurrentPrepopulatedPromptLen;
}
void setCurrentPrepopulatedPromptLen(SizeType32 currentPrepopulatedPromptLen)
{
TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
"currentPrepopulatedPromptLen must be updated non-increasingly due to the "
"assumption that smaller window sizes have shorter or equal"
"currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
}
private:
// Request id of the sequence
LlmRequest::RequestIdType mRequestId;
@ -517,6 +532,8 @@ private:
SizeType32 mNumFrontBlocksRemoved;
// Set of used blocks by the sequence
std::set<KVCacheBlock::IdType> mUsedBlocks;
// Current prepopulated prompt length
SizeType32 mCurrentPrepopulatedPromptLen;
};
// attach metadata to a pool pointer
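The getter and setter added above initialize the value to the numeric maximum and only ever lower it, so once every window-size manager has reported its reuse result, the sequence-level value is effectively the minimum over all window sizes. A minimal standalone sketch of that invariant, assuming hypothetical matched-token counts and modeling the non-increasing update with `std::min` (the real setter asserts the ordering instead):

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    // Mirrors mCurrentPrepopulatedPromptLen's initialization to the maximum value.
    int currentPrepopulatedPromptLen = std::numeric_limits<int>::max();
    // Hypothetical matched-token counts reported per window size, larger windows first.
    std::vector<int> matchedPerWindow{96, 96, 64};
    for (int matched : matchedPerWindow)
    {
        // Models setCurrentPrepopulatedPromptLen(): updates never increase the value.
        currentPrepopulatedPromptLen = std::min(currentPrepopulatedPromptLen, matched);
    }
    // Prints 64: the smallest window size bounds what the attention kernel can reuse.
    std::printf("prepopulated prompt length: %d\n", currentPrepopulatedPromptLen);
    return 0;
}
```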

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/executor/serialization.h"
#include <atomic>
#include <vector>
namespace tensorrt_llm::executor::kv_cache
@ -27,8 +28,9 @@ class CommState;
struct DataContext
{
public:
explicit DataContext(int tag)
explicit DataContext(int tag, std::atomic<bool> const& transferTerminate = sDefaultTransferTerminate)
: mTag{tag}
, mTransferTerminate(transferTerminate)
{
}
@ -37,8 +39,15 @@ public:
return mTag;
}
[[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
{
return mTransferTerminate;
}
private:
inline static std::atomic<bool> sDefaultTransferTerminate{false};
int const mTag;
std::atomic<bool> const& mTransferTerminate;
};
class Connection

View File

@ -1468,7 +1468,8 @@ public:
DEFAULT = 0,
MPI = 1,
UCX = 2,
NIXL = 3
NIXL = 3,
MOONCAKE = 4
};
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,

View File

@ -274,13 +274,20 @@ private:
std::optional<SyncMessage> mSyncMessage;
};
enum class TransferState : uint8_t
{
kIN_PROGRESS,
kSUCCESS,
kFAILURE,
};
// Data structure for checking the status of active transfer operations.
class TransferStatus
{
public:
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual void wait() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
};
struct BaseAgentConfig
@ -288,6 +295,8 @@ struct BaseAgentConfig
std::string mName;
bool useProgThread;
bool multiThread;
bool useListenThread;
unsigned int numWorkers;
};
class BaseTransferAgent
@ -391,6 +400,14 @@ template <typename... Args>
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
return func(std::forward<Args>(args)...);
}
if (backend == "mooncake")
{
auto& loader = DynLibLoader::getInstance();
using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
return func(std::forward<Args>(args)...);
}
TLLM_THROW("Unknown backend name.");
}
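Given the `TransferState` enum and the timeout-aware `TransferStatus::wait()` introduced above, a caller can poll an in-flight transfer instead of blocking indefinitely. The sketch below is a hypothetical usage pattern rather than library code: the types are re-declared locally only to keep it self-contained, the 100 ms interval is arbitrary, and the assumption that `wait()` returns `kIN_PROGRESS` when the timeout expires is not confirmed by the header.

```cpp
#include <cstdint>

// Local re-declarations matching the interface above, so this sketch compiles on its own.
enum class TransferState : uint8_t
{
    kIN_PROGRESS,
    kSUCCESS,
    kFAILURE,
};

class TransferStatus
{
public:
    virtual ~TransferStatus() = default;
    [[nodiscard]] virtual bool isCompleted() const = 0;
    virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
};

// Hypothetical polling helper: returns true on success, false on failure.
bool waitForTransfer(TransferStatus const& status, int64_t pollIntervalMs = 100)
{
    for (;;)
    {
        // Assumption: a timeout without completion reports kIN_PROGRESS.
        TransferState state = status.wait(pollIntervalMs);
        if (state == TransferState::kSUCCESS)
        {
            return true;
        }
        if (state == TransferState::kFAILURE)
        {
            return false;
        }
        // kIN_PROGRESS: keep polling; a real caller might also honor a terminate flag here.
    }
}
```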

View File

@ -104,12 +104,14 @@ public:
[[nodiscard]] SizeType32 constexpr getTensorParallelRank() const noexcept
{
return mRank % mTensorParallelism;
// Layout: pp is outermost, then tp, then cp is innermost (consecutive).
return (mRank % (mTensorParallelism * mContextParallelism)) / mContextParallelism;
}
[[nodiscard]] SizeType32 constexpr getContextParallelRank() const noexcept
{
return (mRank % (mTensorParallelism * mContextParallelism)) / mTensorParallelism;
// Layout: pp is outermost, then tp, then cp is innermost (consecutive).
return mRank % mContextParallelism;
}
[[nodiscard]] SizeType32 constexpr getLocalRank() const noexcept
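To make the layout in the comments above concrete (pipeline parallelism outermost, then tensor parallelism, with context parallelism innermost over consecutive ranks), here is a small standalone sketch; the `decompose` helper and the 2x2x2 example sizes are illustrative, not part of WorldConfig. The same `(rank % (tp * cp)) / cp` expression reappears in the CacheFormatter and MLACacheFormatter changes further below.

```cpp
#include <cstdio>

struct DecomposedRank
{
    int pp, tp, cp;
};

// Decompose a flat rank assuming pp is outermost, tp next, and cp innermost (consecutive).
DecomposedRank decompose(int rank, int tpSize, int cpSize)
{
    DecomposedRank r{};
    r.cp = rank % cpSize;                       // matches getContextParallelRank()
    r.tp = (rank % (tpSize * cpSize)) / cpSize; // matches getTensorParallelRank()
    r.pp = rank / (tpSize * cpSize);            // outermost dimension
    return r;
}

int main()
{
    // Example: tp = 2, cp = 2, pp = 2 -> 8 ranks total.
    for (int rank = 0; rank < 8; ++rank)
    {
        auto r = decompose(rank, /*tpSize=*/2, /*cpSize=*/2);
        std::printf("rank %d -> pp %d, tp %d, cp %d\n", rank, r.pp, r.tp, r.cp);
    }
    return 0;
}
```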

View File

@ -69,6 +69,11 @@ PREPROCESSOR_FLAGS += -DUSE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE
# Do we want to use half accumulation for flash attention
PREPROCESSOR_FLAGS += -DHALF_ACCUMULATION_FOR_FLASH_ATTENTION
# Print the resulting sparsity for a given threshold in Skip-Softmax attention.
# Note: running "python scripts/build_wheel.py -D SKIP_SOFTMAX_STAT=ON ..." is all you need to use it inside TRT-LLM.
# Turn this on manually only if you want to build and run the unit test (bin/fmha.exe) with SKIP_SOFTMAX_STAT.
# PREPROCESSOR_FLAGS += -DSKIP_SOFTMAX_STAT
# Add FLAGS when generating cubins.
ifdef GENERATE_CUBIN
PREPROCESSOR_FLAGS += -DGENERATE_CUBIN

View File

@ -154,7 +154,9 @@ spec_fields = (
'head_size_v',
'sage_block_sizes',
'output_dtype',
'is_mtp')
'is_mtp',
'enable_skip_softmax',
)
kernel_spec = namedtuple('kernel_spec', spec_fields)
kernel_spec.__new__.__defaults__ = (
1, # ctas_per_head
@ -179,7 +181,9 @@ kernel_spec.__new__.__defaults__ = (
0, # head size of V
None, # sage_block_sizes
None, # output_dtype, same as dtype by default.
False) # use MTP or not
False, # use MTP or not
False, # enable skip softmax
)
generate_cu_trtllm = os.environ.get('GENERATE_CU_TRTLLM',
'False').lower() == 'true'
@ -1435,6 +1439,7 @@ using Ktraits = {kernel_traits_header}
USE_TMA_STORE,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_},
{sage_block_size_q},
{sage_block_size_k},
@ -1458,6 +1463,7 @@ using Ktraits_causal = {kernel_traits_header}
USE_TMA_STORE,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
@ -1478,6 +1484,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
USE_TMA_STORE && false,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
using Ktraits_custom_mask = {kernel_traits_header}
@ -1498,6 +1505,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
USE_TMA_STORE && false,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag},
{enable_skip_softmax_flag},
{output_dtype_}>;
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1835,6 +1843,8 @@ def encode_name(kernel_spec):
if kernel_spec.enable_attn_logit_softcapping:
feature_tags += '_softcapping'
if kernel_spec.enable_skip_softmax:
feature_tags += '_skipSoftmax'
if kernel_spec.sage_block_sizes:
feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
if kernel_spec.output_dtype:
@ -2131,6 +2141,8 @@ def get_kernel_code(kspec, kname, lname):
return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats]
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
# needed by warpspec kernels.
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
@ -2331,6 +2343,8 @@ def get_api_code(specs_names):
f'&& sage_block_size_k == {sage_block_size_k} ' \
f'&& sage_block_size_v == {sage_block_size_v} '
il_check += '&& enable_skip_softmax ' if kspec.enable_skip_softmax else '&& !enable_skip_softmax '
il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '
slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0
@ -2607,6 +2621,7 @@ const bool warp_specialization = launch_params.warp_specialization
const bool use_tma = launch_params.use_tma;
const bool use_flash_attention = launch_params.flash_attention;
const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping;
const bool enable_skip_softmax = launch_params.enable_skip_softmax;
const int attention_input_layout = static_cast<int>(launch_params.attention_input_layout);
// tiled variant uses ldgsts
const bool use_tiled = launch_params.use_granular_tiling;
@ -2785,6 +2800,8 @@ def get_kernel_traits_code(specs_names):
enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
kspec.enable_attn_logit_softcapping]
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
tmp = dict(locals(), **kspec._asdict())
if effective_sm < 90:
@ -2903,7 +2920,8 @@ def get_kernel_traits_code(specs_names):
{input_layout_flag},
__use_tma_store__ /* USE_TMA_STORE */,
{enable_attn_logit_softcapping_flag},
{return_softmax_stats_flag}>;
{return_softmax_stats_flag},
{enable_skip_softmax_flag}>;
printf("%s %d %d %s %d %d\\n",
\"{kname}\",
@ -3062,9 +3080,16 @@ def get_kernel_traits_code(specs_names):
# For now:
# 1. Hopper head_size 128 kernels use cubins to avoid performance regressions.
# 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins to avoid accuracy regressions (will be fixed).
# 3. For the skip-softmax attention feature, we force cubins not to be used.
# You should set the condition `use_cubin_header` to false if you have modified the source code of those kernels that use cubins.
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
def use_cubin_header(sm, head_size, dtype, output_dtype=None):
def use_cubin_header(sm,
head_size,
dtype,
output_dtype=None,
enable_skip_softmax=False):
if enable_skip_softmax:
return False
if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
return False
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
@ -3079,7 +3104,8 @@ def get_cubin_header(kernel_traits, specs_names):
launchers_dict = {}
for kspec, fname, lname, kname in specs_names:
if generate_cu_trtllm and not use_cubin_header(
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype,
kspec.enable_skip_softmax):
continue
name = fname.replace('.', '_')
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@ -3111,8 +3137,9 @@ def get_cubin_header(kernel_traits, specs_names):
'q_kv_', '').replace('q_paged_kv_', '').replace(
'q_k_v_', '').replace('ws_', '').replace(
'softcapping_',
'').replace('sage_',
'').replace('output_', ''))
'').replace('sage_', '').replace(
'skipSoftmax_',
'').replace('output_', ''))
flash_attention = 'flash_attention' in kname
warp_specialization = 'tma_ws' in kname
toks = tname.split('_')
@ -3209,6 +3236,8 @@ def get_cubin_header(kernel_traits, specs_names):
return_softmax_stats_flag = pythonBoolean2cpp[sm != '90' or (
sm == '90' and '_softmax' in kname)]
enable_skip_softmax_flag = pythonBoolean2cpp['_skipSoftmax' in kname]
# meta_unroll_step
meta_unroll_step = unroll_step if ('_nl' in kname
or '_ws' in kname) else '0'
@ -3235,7 +3264,8 @@ def get_cubin_header(kernel_traits, specs_names):
def get_lname_from_kname(kname: str) -> str:
if use_cubin_header(int(sm), int(head_size), prec.lower(),
output_prec.lower()):
output_prec.lower(),
enable_skip_softmax_flag):
return 'nullptr'
lname = kname.replace('_kernel', '')
mask_types = [
@ -3253,15 +3283,15 @@ def get_cubin_header(kernel_traits, specs_names):
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
'''.format(**locals()) if use_cubin_header(int(sm),
int(head_size), prec.lower(),
output_prec.lower()) else '''\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
prec.lower(), output_prec.lower(),
enable_skip_softmax_flag) else '''\
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
'''.format(**locals())
else:
code = '''\
@ -3269,7 +3299,7 @@ def get_cubin_header(kernel_traits, specs_names):
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}}}\
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\
'''.format(**locals())
if sm in metadata_v2_dict:
metadata_v2_dict[sm].append(code)
@ -3377,7 +3407,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}} sMhaKernelMetaInfosV2[] = {{
{metadata_v2}
}};
@ -3438,6 +3469,7 @@ static const struct TestMetaV2
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;
}} metaV2[] = {{
{metadata_v2}
}};
@ -3484,7 +3516,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}};
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[];
@ -3580,7 +3613,8 @@ struct FusedMultiHeadAttentionKernelMetaInfoV2
bool mAlibiSupported;
bool mTiled;
bool mEnableAttnLogitSoftcapping;
bool mReturnSoftmaxStats;{launcher_line}
bool mReturnSoftmaxStats;
bool mEnableSkipSoftmax;{launcher_line}
}};
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[] = {{
@ -3637,7 +3671,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_
return '\n'.join(lines)
target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, nullptr},'
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, false, nullptr},'
result = modify_kernel_line(result, target, new_line)
# make sure only one empty line at the end
@ -3801,7 +3835,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
# Note this will be used in TRT-LLM.
def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
def enumerate_hgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='fp16',
enable_skip_softmax=False):
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
@ -3851,7 +3888,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
specs.append(
kernel_spec(
@ -3883,7 +3921,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
specs.append(
kernel_spec(
@ -3915,7 +3954,8 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
input_layout=input_layout,
enable_skip_softmax=enable_skip_softmax))
'''
smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
+ (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
@ -3967,7 +4007,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='e4m3',
sage_block_sizes=None,
output_dtype=None):
output_dtype=None,
enable_skip_softmax=False):
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
@ -4021,7 +4062,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
# 64 < D <=128: KV_STEP = 128
specs.append(
@ -4056,7 +4098,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
# 128 < D <=256: KV_STEP = 128
specs.append(
@ -4092,7 +4135,8 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))
output_dtype=output_dtype,
enable_skip_softmax=enable_skip_softmax))
if not skip_mla_combination:
# context MLA (192x128)
@ -6374,13 +6418,21 @@ def enumerate_kernels():
enumerate_igmma_kernels(specs, sm=90)
enumerate_qgmma_kernels(specs, sm=90)
# need to add bf16 kernels if needed
enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
enumerate_qgmma_flash_warpspec_kernels(specs,
sm=90,
dtype='e4m3',
output_dtype="bf16")
for enable_skip_softmax in [False, True]:
if enable_skip_softmax and 'DISABLE_SKIP_SOFTMAX' in os.environ:
continue
enumerate_hgmma_flash_warpspec_kernels(
specs, sm=90, dtype='fp16', enable_skip_softmax=enable_skip_softmax)
enumerate_hgmma_flash_warpspec_kernels(
specs, sm=90, dtype='bf16', enable_skip_softmax=enable_skip_softmax)
enumerate_qgmma_flash_warpspec_kernels(
specs, sm=90, dtype='e4m3', enable_skip_softmax=enable_skip_softmax)
enumerate_qgmma_flash_warpspec_kernels(
specs,
sm=90,
dtype='e4m3',
output_dtype="bf16",
enable_skip_softmax=enable_skip_softmax)
# For now SageAttention only needs BF16
# block_size_q should be divisible by 64

View File

@ -256,7 +256,8 @@ struct Compute
actual_kv_seqlen, alibi_head_scale, \
USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
: (q_step_idx * STEP_Q + head_info.q_tile_offset), \
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, kv_step_idx == kv_idx_end - 1);
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
&shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);
////////////////////////////////////////////////////////////////////////////////////////////////
@ -360,6 +361,12 @@ struct Compute
// Contiguous QKV FMHA assumes q, and kv have the same sequence length.
int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;
// Update threshold of Skip-Softmax
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
}
// Calculate the alibi head_scaling_factor.
float alibi_head_scale
= APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
@ -513,6 +520,13 @@ struct Compute
}
}
}
#ifdef SKIP_SOFTMAX_STAT
if (tidx == 0)
{
atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////
@ -522,8 +536,15 @@ struct Compute
Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
OrderedMutexAccessor& mutex, bool complete = false)
OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
{
// Skip-softmax vote initialization
if (tidx == 0)
{
// Note that we need a named_barrier_wait in compute_single_tile to make sure the initialization happens before the voting.
*skip_softmax_vote = 1;
}
// load the scales of K/V from global memory
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
if constexpr (block_size > 0) \
@ -557,6 +578,10 @@ struct Compute
// Ctile_p is only used once by each n step.
ctile_p.clear();
// If skip_softmax is enabled, make sure there is no race between the initialization and the writing of
// skip_softmax_vote.
named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
// BMM1 (Q x K').
warpgroup_arrive();
@ -626,8 +651,22 @@ struct Compute
softmax.apply_alibi_and_mask<APPLY_MASK>(
ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);
// Softmax Exp, max/sum, and update scales.
softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
// Softmax Exp, max/sum, and update scales. If it returns false, we skip the rest.
if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
{
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
{
// Notify another warpgroup to execute QGMMA.
mutex.named_bar_arrive();
}
// Need to wait for V, otherwise the compute-sanitizer synccheck will fail.
int ready2 = cbr_v.peek();
if (!ready2)
{
cbr_v.wait();
}
return;
}
// experiments show that here is the best place to load scales of V
float scales_v[SAGE_BLOCKS_PER_STEP_V];

View File

@ -17,6 +17,8 @@
#pragma once
#include "fmha/hopper/arrive_wait.h"
#include <fmha/softmax.h>
#include <fmha/traits.h>
#include <fmha/utils.h>
@ -104,6 +106,12 @@ struct Softmax_base
CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK
};
// There are 2 warpgroups so 0x3 and 0x4 are used
enum
{
SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID
};
// Ctor.
template <typename Params>
inline __device__ Softmax_base(Params params, int tidx)
@ -114,6 +122,11 @@ struct Softmax_base
, log2_chunked_attention_size_(params.log2_chunked_attention_size)
, packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)}
, params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes}
#ifdef SKIP_SOFTMAX_STAT
, total_blocks(0)
, skipped_blocks(0)
#endif
, skip_softmax_threshold(0)
{
int warp = tidx / 32;
@ -330,24 +343,22 @@ struct Softmax_base
}
// Calculate max/sum, and update flash-attention scales.
// Returns false if skipped due to the skip-softmax attention feature.
template <bool IS_FIRST_COL>
inline __device__ void compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
inline __device__ bool compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
{
float const scale = reinterpret_cast<float const&>(scale_bmm1_);
// whether this warpgroup skips the softmax
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
bool skip = may_skip;
// Row-wise max of current tile.
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
if (IS_FIRST_COL)
{
local_max_[mi] = elt_[mi][0];
}
else
{
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
}
local_max_[mi] = elt_[mi][0];
#pragma unroll
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
{
@ -355,6 +366,56 @@ struct Softmax_base
}
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
if constexpr (may_skip)
{
// AND(&) the CORES_M results, so `skip` indicates whether to skip
// the CORES_M (=2) rows.
if constexpr (!EXP2F_OPTIMIZATION)
{
skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
}
else
{
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
}
}
if (!IS_FIRST_COL)
{
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
}
}
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
#ifdef SKIP_SOFTMAX_STAT
total_blocks++;
#endif
if constexpr (may_skip)
{
// AND(&) the results together within a warp, so `skip` indicates whether to skip
// all 16 rows managed by this warp.
// Each group of 4 threads (e.g. T0~T3) has the same `skip`, so a mask of 0x11111111 would suffice
// instead of 0xffffffff, but the performance is the same.
skip = __all_sync(0xffffffff, skip);
if (threadIdx.x % 32 == 0)
{
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
skip = *((uint32_t volatile*) skip_softmax_vote);
if (skip)
{
#ifdef SKIP_SOFTMAX_STAT
skipped_blocks++;
#endif
return false;
}
}
}
// Softmax Exp.
@ -436,6 +497,7 @@ struct Softmax_base
global_max[mi] = max_new;
}
}
return true;
}
// Update flash attention scales and pack elements for BMM2.
@ -513,6 +575,13 @@ struct Softmax_base
float correction_[Mma_tile_p::CORES_M];
// The packed mask.
uint4 packed_mask_;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
float skip_softmax_threshold;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax
uint32_t total_blocks;
uint32_t skipped_blocks;
#endif
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -868,9 +937,10 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
}
// Calculate max/sum, and update flash-attention scales.
// Returns false if skipped due to the skip-softmax attention feature.
template <bool IS_FIRST_COL>
inline __device__ void compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M])
inline __device__ bool compute_and_update_scale(
float (&global_max)[Mma_tile_p::CORES_M], float (&global_sum)[Mma_tile_p::CORES_M], uint32_t* skip_softmax_vote)
{
float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
@ -878,18 +948,15 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;
// whether this warpgroup skips the softmax
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
bool skip = may_skip;
// Row-wise max of current tile.
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
if (IS_FIRST_COL)
{
local_max_[mi] = elt_[mi][0];
}
else
{
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
}
local_max_[mi] = elt_[mi][0];
#pragma unroll
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++)
{
@ -897,6 +964,56 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
}
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
if constexpr (may_skip)
{
// AND(&) the CORES_M results, so `skip` indicates whether to skip
// the CORES_M (=2) rows.
if constexpr (!EXP2F_OPTIMIZATION)
{
skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
}
else
{
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
}
}
if (!IS_FIRST_COL)
{
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
}
}
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
#ifdef SKIP_SOFTMAX_STAT
this->total_blocks++;
#endif
if constexpr (may_skip)
{
// AND(&) the results together within a warp, so `skip` indicates whether to skip
// all 16 rows managed by this warp.
// Each group of 4 threads (e.g. T0~T3) has the same `skip`, so a mask of 0x11111111 would suffice
// instead of 0xffffffff, but the performance is the same.
skip = __all_sync(0xffffffff, skip);
if (threadIdx.x % 32 == 0)
{
// The leader of each warp votes.
atomicAnd(skip_softmax_vote, uint32_t(skip));
}
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
skip = *((uint32_t volatile*) skip_softmax_vote);
if (skip)
{
#ifdef SKIP_SOFTMAX_STAT
this->skipped_blocks++;
#endif
return false;
}
}
}
// Softmax Exp.
@ -987,6 +1104,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
global_max[mi] = max_new;
}
}
return true;
}
// Update flash attention scales and pack elements for BMM2.

View File

@ -71,6 +71,8 @@ template <
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = typename Instruction_traits<STEP_Q_, STEP_KV_, 0, false, false>::A_type,
// The sage attention block size for Q, K and V
@ -290,6 +292,12 @@ struct Kernel_traits
USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3
};
// Are we enabling skip softmax attention feature?
enum
{
ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_
};
static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256, "Not implemented!");
// Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs).
@ -384,6 +392,8 @@ struct Kernel_traits
// Named barrier ids
static constexpr int DMA_SYNC_BARRIER_ID = 0x1;
static constexpr int MMA_SYNC_BARRIER_ID = 0x2;
// There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax
static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3;
// How many threads get involved in the dma group.
enum
@ -518,6 +528,10 @@ struct Kernel_traits
// Mutex
OrderedMutex compute_mutex;
// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive KV_STEPS.
uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS];
inline __device__ void init(int tid0)
{
@ -580,6 +594,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
// Save softmax stats ?
bool RETURN_SOFTMAX_STATS_ = false,
// Enable skip softmax attention feature
bool ENABLE_SKIP_SOFTMAX_ = false,
// The output type (only used by fp8 kernels).
typename OutputType = e4m3_t,
// The sage attention block size for Q, K and V
@ -588,14 +604,15 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
: public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_,
SAGE_BLOCK_SIZE_V_>
{
// Base class.
using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_,
OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
enum
{
@ -693,6 +710,10 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
// Mutex
OrderedMutex compute_mutex;
// 4 warps in a warpgroup vote to an atomic variable in shared memory
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive STEP_KVs.
uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS];
inline __device__ void init(int tid0)
{

View File

@ -276,7 +276,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
// scale factors
float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1,
// flags
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi)
bool const use_int8_scale_max, bool const interleaved, bool const is_s_padded, bool const has_alibi,
float const skip_softmax_threshold_scale_factor)
{
memset(&params, 0, sizeof(params));
@ -421,6 +422,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params,
params.enable_i2f_trick
= -double(1 << 22) * double(scale_bmm2) <= -128.f && double(1 << 22) * double(scale_bmm2) >= 127.f;
}
// Skip-softmax attention
params.skip_softmax_threshold_scale_factor = skip_softmax_threshold_scale_factor;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -429,7 +433,7 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout,
bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma,
bool const force_non_flash_attention, bool const force_non_warp_specialization,
bool const force_non_granular_tiling, bool const force_fp32_acc,
bool const force_non_granular_tiling, bool const force_fp32_acc, float const skip_softmax_threshold_scale_factor,
// device props
const cudaDeviceProp props)
{
@ -470,6 +474,9 @@ static inline void determine_launch_params(Launch_params& launch_params, Data_ty
"are not supported on Ada currently.\n");
launch_params.use_granular_tiling = false;
}
// Enable skip softmax attention or not.
launch_params.enable_skip_softmax = skip_softmax_threshold_scale_factor > 0.f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -589,6 +596,9 @@ int main(int argc, char** argv)
// Use attention sinks (added to the denominator of softmax)
bool use_attention_sinks = false;
// Skip-softmax attention
float skip_softmax_threshold_scale_factor = 0;
// Read the parameters from the command-line.
for (int ii = 1; ii < argc; ++ii)
{
@ -885,6 +895,10 @@ int main(int argc, char** argv)
{
use_attention_sinks = true;
}
else if (!strcmp(argv[ii], "-skip-softmax-threshold-scale-factor") && ++ii < argc)
{
skip_softmax_threshold_scale_factor = strtof(argv[ii], nullptr);
}
else
{
fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]);
@ -1057,7 +1071,7 @@ int main(int argc, char** argv)
Launch_params launch_params;
determine_launch_params(launch_params, data_type, sm, s, d, attention_mask_type, input_layout, interleaved,
ignore_b1opt, force_unroll, use_tma, force_non_flash_attention, force_non_warp_specialization,
force_non_granular_tiling, force_fp32_acc, props);
force_non_granular_tiling, force_fp32_acc, skip_softmax_threshold_scale_factor, props);
// The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D.
const size_t qkv_size = s * b * h * (2 * d + dv);
@ -1713,7 +1727,13 @@ int main(int argc, char** argv)
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d,
packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d,
softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1,
use_int8_scale_max, interleaved, is_s_padded, has_alibi);
use_int8_scale_max, interleaved, is_s_padded, has_alibi, skip_softmax_threshold_scale_factor);
#ifdef SKIP_SOFTMAX_STAT
FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_total_blocks, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMalloc(&params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_total_blocks, 0, sizeof(uint32_t)));
FMHA_CHECK_CUDA(cudaMemset(params_v2.skip_softmax_skipped_blocks, 0, sizeof(uint32_t)));
#endif
// total number of tokens is needed to set TMA desc on the host.
launch_params.total_q_seqlen = q_seqlens[b];
@ -2101,6 +2121,18 @@ int main(int argc, char** argv)
non_fused_elapsed / fused_elapsed, total_flops / (fused_elapsed / float(runs) / 1e-9),
total_bytes / (fused_elapsed / float(runs) / 1e-6));
}
#ifdef SKIP_SOFTMAX_STAT
if (skip_softmax_threshold_scale_factor > 0)
{
uint32_t total_blocks, skipped_blocks;
FMHA_CHECK_CUDA(
cudaMemcpy(&total_blocks, params_v2.skip_softmax_total_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
FMHA_CHECK_CUDA(cudaMemcpy(
&skipped_blocks, params_v2.skip_softmax_skipped_blocks, sizeof(uint32_t), cudaMemcpyDeviceToHost));
printf("Skip-Softmax .: %u / %u = %.2f%%\n", skipped_blocks, total_blocks,
total_blocks ? 100.f * skipped_blocks / total_blocks : 0.f);
}
#endif
#if defined(DEBUG_HAS_PRINT_BUFFER)
FMHA_CHECK_CUDA(cuda_memcpy_d2h(print_buffer.data(), params.print_ptr, print_buffer.size(), DATA_TYPE_FP32));
@ -2141,6 +2173,11 @@ int main(int argc, char** argv)
FMHA_CHECK_CUDA(cudaFree(kv_cache_block_offsets_d));
FMHA_CHECK_CUDA(cudaFree(contiguous_kv_d));
FMHA_CHECK_CUDA(cudaFree(softmax_stats_d));
FMHA_CHECK_CUDA(cudaFree(attention_sinks_d));
#ifdef SKIP_SOFTMAX_STAT
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_total_blocks));
FMHA_CHECK_CUDA(cudaFree(params_v2.skip_softmax_skipped_blocks));
#endif
free(qkv_h);
free(mask_h);

View File

@ -283,6 +283,16 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
float* scales;
} q, k, v;
} sage;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
// Skip-softmax statistics: pointers to device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};
#endif
@ -322,6 +332,8 @@ struct Fused_multihead_attention_launch_params
// hardware properties to determine how to launch blocks
int multi_processor_count = 0;
int device_l2_cache_size = 0;
// skip softmax attention
bool enable_skip_softmax = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -177,4 +177,13 @@ struct Fused_multihead_attention_params_v2
float* scales;
} q, k, v;
} sage;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
// Skip-softmax statistics: pointers to device memory for output
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};
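The comment above defines the skip rule as `exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen`, with a positive scale factor enabling the feature. The host-side sketch below simply evaluates that inequality with made-up numbers to show the orders of magnitude involved; it is not part of the kernel.

```cpp
#include <cmath>
#include <cstdio>

// Returns whether a KV tile would be skipped under the documented rule.
bool wouldSkip(float localMax, float globalMax, float scaleFactor, int seqlen)
{
    float const threshold = scaleFactor / static_cast<float>(seqlen);
    return std::exp(localMax - globalMax) < threshold;
}

int main()
{
    float const scaleFactor = 1e-2f; // a positive value enables skip-softmax (illustrative)
    int const seqlen = 8192;         // threshold becomes ~1.2e-6
    // A tile whose row max sits ~20 below the running max contributes ~exp(-20) ~ 2e-9,
    // well under the threshold, so it is skipped (prints 1).
    std::printf("skip? %d\n", wouldSkip(-25.f, -5.f, scaleFactor, seqlen));
    // A tile close to the running max (exp(-0.5) ~ 0.61) is kept (prints 0).
    std::printf("skip? %d\n", wouldSkip(-5.5f, -5.f, scaleFactor, seqlen));
    return 0;
}
```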

View File

@ -157,6 +157,11 @@ set(UCX_WRAPPER_TARGET tensorrt_llm_ucx_wrapper)
if(NIXL_ROOT)
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
set(TRANSFER_AGENT_BINDING_TARGET tensorrt_llm_transfer_agent_binding)
endif()
if(MOONCAKE_ROOT)
set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
endif()
add_subdirectory(executor)
@ -272,6 +277,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
endif()
if(TARGET ${MOONCAKE_WRAPPER_TARGET})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
endif()
if(NOT WIN32)
# Load libraries at $PREFIX/lib from
# $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs

View File

@ -154,7 +154,8 @@ bool CacheFormatter::needSendCache(
return true;
}
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
int selfTpRankInDpGroup = selfTpRank;
if (selfConfig.getParallelConfig().mEnableAttentionDP)
{

View File

@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
}
else if (common::getEnvUseMooncakeKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
}
else if (common::getEnvUseMPIKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState);
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());

View File

@ -358,8 +358,9 @@ public:
TransceiverTag::Id id;
RequestInfo info;
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
auto const* connection = isAgent
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
if (connection == nullptr && !mManager->isRunning())
{
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
@ -395,8 +396,8 @@ public:
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
DataContext{tagFromRequestId(requestId), mTerminate}, mSelfState, info.getTransState(),
mBufferManager, info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
!common::getEnvKVCacheTimeOutputPath().empty());
session.setTime(TransferSession::kTimeRequestInfo);
it = mRequestToSession.emplace(requestId, std::move(session)).first;
@ -685,6 +686,10 @@ private:
{
future.get();
}
if (mResponseFuture.valid())
{
mResponseFuture.get();
}
}
void removeResponse(std::map<RequestIdType, Response>::iterator it)
@ -886,9 +891,9 @@ public:
}
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
&llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId), mTerminate},
mSelfState, contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(),
requestInfo.getLastBlockKey(), &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
}
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@ -964,7 +969,7 @@ public:
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
TLLM_CHECK(agentConnection);
isReady = agentConnection->recvReadySignal(
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG, mTerminate});
}
else
{
@ -979,6 +984,7 @@ public:
~Impl()
{
mTerminate.store(true);
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
{
asyncResource->mTerminate = true;
@ -1134,6 +1140,7 @@ private:
runtime::BufferManager mBufferManager;
std::ofstream mMeasuresFile;
std::mutex mMeasuresFileMutex;
std::atomic<bool> mTerminate{false};
};
void CacheSender::ImplDeleter::operator()(Impl* ptr)

View File

@ -1224,7 +1224,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr && blockItr != blockKeys.end()
? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse)
: std::make_tuple(false, 0, nullptr);
if (matchingBlock != nullptr)
if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen())
{
KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId();
@ -1338,6 +1338,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
}
}
sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens);
return numMatchedTokens;
}
@ -1731,9 +1732,22 @@ std::optional<KVCacheBlock::IdType> BlockManager::releaseBlocks(
// Released block will be stored when reuse is enabled.
// Reuse is implied to be enabled if llmRequest is provided.
std::optional<KVCacheBlock::IdType> lastStoredId = std::nullopt;
// For now, the attention kernel only accepts a single
// "prepopulatedPromptLen"; that is, all window sizes use the same
// prepopulated prompt length, so it is currently pointless to store
// blocks for reuse for one window size while the blocks of another
// window size are not valid for storing.
bool isAllWindowSizesValidForStoreForReuse = true;
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
isAllWindowSizesValidForStoreForReuse &= manager.isSequenceValidForStoreForReuse(sequence.getRequestId());
}
for (auto& [_, manager] : mWindowBlockManagers)
{
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1)
if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1
|| !isAllWindowSizesValidForStoreForReuse)
{
lastStoredId = manager.releaseBlocks(sequence, std::nullopt);
}

View File

@ -60,7 +60,8 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
bool MLACacheFormatter::needSendCache(
CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
{
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
int selfCpSize = selfConfig.getParallelConfig().mContextParallelism;
int selfTpRank = (selfIdx % (selfConfig.getParallelConfig().mTensorParallelism * selfCpSize)) / selfCpSize;
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize

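The rewritten expression folds context parallelism into the flat rank before extracting the TP rank. A worked example of the arithmetic, assuming the flat layout ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank that the expression implies:

#include <cassert>

int main()
{
    int const tp = 4, cp = 2;
    // tpRank = 1, cpRank = 1, ppRank = 0 -> flat rank 3 under the assumed layout.
    int const selfIdx = 0 * (tp * cp) + 1 * cp + 1;
    int const newTpRank = (selfIdx % (tp * cp)) / cp; // = 1, the intended TP rank
    int const oldTpRank = selfIdx % tp;               // = 3, what the previous expression yielded
    assert(newTpRank == 1 && oldTpRank == 3);
    return 0;
}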
View File

@ -296,7 +296,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
// Parameters for sparse attention
xqaParams.sparse_params = mRuntimeSparseAttentionParams;
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
// Skip softmax threshold.
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@ -1313,6 +1314,8 @@ int AttentionOp::mlaGeneration(
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
}
// MLA does not support skip-softmax attention right now
// Run the fmha kernel
mDecoderFMHARunner->run(fmhaParams);
}
@ -1885,6 +1888,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
fmhaParams.sparse_params = mRuntimeSparseAttentionParams;
}
// Skip-softmax attention parameters
fmhaParams.skipSoftmaxThresholdScaleFactor = mSkipSoftmaxThresholdScaleFactorPrefill;
#ifdef SKIP_SOFTMAX_STAT
fmhaParams.skipSoftmaxTotalBlocks = mSkipSoftmaxTotalBlocks;
fmhaParams.skipSoftmaxSkippedBlocks = mSkipSoftmaxSkippedBlocks;
#else
if (tensorrt_llm::common::getEnvPrintSkipSoftmaxStat())
{
TLLM_THROW("To print skip softmax stat, please run build_wheel.py with -DSKIP_SOFTMAX_STAT");
}
#endif
if (mAttentionChunkSize)
{
fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;
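The statistics pointers only exist when the wheel is built with -DSKIP_SOFTMAX_STAT; otherwise asking for the printout throws early. A self-contained sketch of that compile-time gating pattern; the names here are illustrative stand-ins, not the actual AttentionOp members:

#include <cstdlib>
#include <stdexcept>

struct SkipSoftmaxParamsSketch
{
#ifdef SKIP_SOFTMAX_STAT
    unsigned* totalBlocks = nullptr;   // counters compiled in only with the build flag
    unsigned* skippedBlocks = nullptr;
#endif
    float thresholdScaleFactor = 0.0f;
};

inline void configureSkipSoftmax(SkipSoftmaxParamsSketch& params)
{
#ifdef SKIP_SOFTMAX_STAT
    static unsigned total = 0, skipped = 0;
    params.totalBlocks = &total;
    params.skippedBlocks = &skipped;
#else
    // Fail fast if stats were requested but the counters were compiled out.
    if (std::getenv("TRTLLM_PRINT_SKIP_SOFTMAX_STAT") != nullptr)
    {
        throw std::runtime_error("Rebuild with -DSKIP_SOFTMAX_STAT to print skip-softmax statistics.");
    }
#endif
    (void) params;
}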

View File

@ -494,6 +494,14 @@ public:
// See [Chunked Attention] in _torch/modules/attention.py
std::optional<int64_t> mAttentionChunkSize = std::nullopt;
// Skip softmax threshold scale factor.
float mSkipSoftmaxThresholdScaleFactorPrefill = 0;
float mSkipSoftmaxThresholdScaleFactorDecode = 0;
#ifdef SKIP_SOFTMAX_STAT
uint32_t* mSkipSoftmaxTotalBlocks;
uint32_t* mSkipSoftmaxSkippedBlocks;
#endif
[[nodiscard]] auto data() const
{
return std::make_tuple(mLayerIdx, mNumHeads, mVisionStart, mVisionLength, mNumKVHeads, mHeadSize,
@ -510,7 +518,8 @@ public:
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
mSkipSoftmaxThresholdScaleFactorDecode);
};
private:

View File

@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
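These hunks replace the hand-written griddepcontrol PTX with the CUDA intrinsics for programmatic dependent launch (PDL). Below is a minimal sketch of how such a kernel is written and launched with the PDL attribute; the kernel, sizes, and helper are illustrative, while the intrinsics and the launch-attribute API are the standard CUDA runtime ones:

#include <cuda_runtime.h>

__global__ void scaleKernelPdl(float* out, float const* in, float scale, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait until the upstream grid's memory writes are visible.
    cudaGridDependencySynchronize();
#endif
    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += blockDim.x * gridDim.x)
    {
        out[i] = in[i] * scale;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Allow dependent grids to start their prologue before this grid exits.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

// Host side: opt the launch into programmatic stream serialization.
void launchScalePdl(float* out, float const* in, float scale, int n, cudaStream_t stream)
{
    cudaLaunchConfig_t config{};
    config.gridDim = dim3(108); // illustrative grid/block sizes
    config.blockDim = dim3(256);
    config.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attrs;
    config.numAttrs = 1;

    cudaLaunchKernelEx(&config, scaleKernelPdl, out, in, scale, n);
}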

View File

@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
return common::getEnvAllReduceWorkspaceSize();
}
if (worldSize <= 2)
char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
if (envWorkspaceSize != nullptr)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
{90,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 1024}},
{8, {2048, 512 * 512}},
}},
{100,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 2048}},
{8, {4096, 1024 * 1024}},
}},
};
inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
{
// The heuristic is based on the following assumptions:
// __________________________________
// | \ TWO-SHOT zone |
// | ONE-SHOT zone \ | NCCL zone
// |_______________________\______|___
// sm_major is 90 or 100
auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
auto const message_size = seq_len * hidden_size;
if (message_size >= two_shot_numel_threshold)
{
return AllReduceStrategyType::TWOSHOT;
}
else
{
return AllReduceStrategyType::ONESHOT;
return static_cast<size_t>(std::atoi(envWorkspaceSize));
}
return 67108864; // 64 MiB
}
// use 1D vector to store the best strategy instead of a map for each sm version
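With the one-shot/two-shot heuristic table gone, the workspace size is either taken from TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE or falls back to 64 MiB. A hedged usage sketch of overriding it from the host process before any workspace is allocated (the 32 MiB value is only an example):

#include <cstdlib>
#include <string>

// Illustrative only: request a smaller fusion workspace than the 64 MiB default.
void requestSmallerFusionWorkspace()
{
    std::string const bytes = std::to_string(32ull * 1024 * 1024);
    setenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE", bytes.c_str(), /*overwrite=*/1);
}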

View File

@ -249,7 +249,7 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
bool getEnvEnablePDL()
{
static std::once_flag flag;
static bool enablePDL = false;
static bool enablePDL = true;
std::call_once(flag,
[&]()
@ -257,7 +257,18 @@ bool getEnvEnablePDL()
if (getSMVersion() >= 90)
{
// PDL is enabled by default on SM90+; set the env variable `TRTLLM_ENABLE_PDL` to `0` to disable it (or `1` to force it on)
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
char const* env = std::getenv("TRTLLM_ENABLE_PDL");
if (env)
{
if (env[0] == '1' && env[1] == '\0')
{
enablePDL = true;
}
else if (env[0] == '0' && env[1] == '\0')
{
enablePDL = false;
}
}
}
});
return enablePDL;
@ -281,6 +292,12 @@ bool getEnvUseNixlKvCache()
return useNixlKvCache;
}
bool getEnvUseMooncakeKvCache()
{
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
return useMooncakeKvCache;
}
bool getEnvUseRoundRobinBlockDistForCP()
{
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
@ -343,6 +360,23 @@ std::string getEnvNixlBackend()
return nixlBackend;
}
std::string getEnvMooncakeInterface()
{
static std::once_flag flag;
static std::string mooncakeInterface;
std::call_once(flag,
[&]()
{
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
if (mooncake_interface)
{
mooncakeInterface = mooncake_interface;
}
});
return mooncakeInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
@ -531,6 +565,11 @@ bool getEnvEplbForceGdrcopy()
return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY");
}
bool getEnvPrintSkipSoftmaxStat()
{
return getBoolEnv("TRTLLM_PRINT_SKIP_SOFTMAX_STAT");
}
} // namespace common
TRTLLM_NAMESPACE_END
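These getters make PDL opt-out on SM90+ and add a Mooncake KV-cache transfer path that is selected purely through environment variables. A short sketch of configuring both before initialization; only the variable names come from this diff, the surrounding code and the interface name are illustrative:

#include <cstdlib>

void configureTransferEnv()
{
    setenv("TRTLLM_ENABLE_PDL", "0", 1);            // "1"/"0" override the SM90+ default of enabled
    setenv("TRTLLM_USE_MOONCAKE_KVCACHE", "1", 1);  // read by getEnvUseMooncakeKvCache()
    setenv("TRTLLM_MOONCAKE_INTERFACE", "eth0", 1); // read by getEnvMooncakeInterface(); NIC name is a placeholder
}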

View File

@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();
bool getEnvUseMPIKvCache();
bool getEnvUseNixlKvCache();
bool getEnvUseMooncakeKvCache();
bool getEnvUseRoundRobinBlockDistForCP();
std::string getEnvUCXInterface();
@ -93,6 +96,8 @@ std::string getEnvNixlInterface();
std::string getEnvNixlBackend();
std::string getEnvMooncakeInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
@ -156,6 +161,8 @@ bool getEnvKVCacheTransferAllBlocksForWindow();
bool getEnvEplbForceGdrcopy();
bool getEnvPrintSkipSoftmaxStat();
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -0,0 +1,226 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <arpa/inet.h>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <cstring>
#include <string>
#include <sys/socket.h>
#include <unistd.h>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIpByNic(std::string const& interface, int rank)
{
struct ifaddrs* ifaddr = nullptr;
if (getifaddrs(&ifaddr) == -1)
{
TLLM_LOG_ERROR(rank,
"getLocalIpByNic: Can't get local IP from NIC interface. Please check whether the corresponding interface "
"is set correctly.");
return std::string{};
}
for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
{
if (ifa->ifa_addr == nullptr)
{
continue;
}
if (ifa->ifa_name == interface)
{
if (ifa->ifa_addr->sa_family == AF_INET)
{
char ip[INET_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
else if (ifa->ifa_addr->sa_family == AF_INET6)
{
char ip[INET6_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
}
}
freeifaddrs(ifaddr);
TLLM_LOG_ERROR(
rank, "Can't get local IP from NIC interface. Please check whether the corresponding interface is set correctly.");
return std::string{};
}
std::string getLocalIpByHostname(int rank)
{
char hostname[256]{};
if (gethostname(hostname, sizeof(hostname)) == -1)
{
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
return std::string{};
}
struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;
struct addrinfo* res = nullptr;
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
{
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
return std::string{};
}
for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
{
if (p->ai_family == AF_INET)
{ // IPv4
char ip[INET_ADDRSTRLEN]{};
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
void* addr = &(ipv4->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
&& std::strcmp(ip, "0.0.0.0") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
else if (p->ai_family == AF_INET6)
{ // IPv6
char ip[INET6_ADDRSTRLEN]{};
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
void* addr = &(ipv6->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
}
freeaddrinfo(res);
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
return std::string{};
}
std::string getLocalIpByRemoteOrHostName(int rank)
{
// Try IPv4
struct sockaddr_in addr
{
};
addr.sin_family = AF_INET;
addr.sin_port = htons(80);
// Use Google's public DNS server to find the local IP that is reachable from remote hosts.
char const* dns_ip_v4 = "8.8.8.8";
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);
int sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
{
socklen_t addr_len = sizeof(addr);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
{
char ip[INET_ADDRSTRLEN]{};
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try IPv6
struct sockaddr_in6 addr6
{
};
addr6.sin6_family = AF_INET6;
addr6.sin6_port = htons(80);
// Use Google's public DNS server again, this time over IPv6.
char const* dns_ipv6 = "2001:4860:4860::8888";
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);
sock = socket(AF_INET6, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
{
socklen_t addr_len = sizeof(addr6);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
{
char ip[INET6_ADDRSTRLEN]{};
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}
// Try hostname
return getLocalIpByHostname(rank);
}
std::string getLocalIp(std::string interface, int rank)
{
std::string localIP = {};
if (!interface.empty())
{
localIP = getLocalIpByNic(interface, rank);
}
if (localIP.empty())
{
localIP = getLocalIpByRemoteOrHostName(rank);
}
// check whether the localIP is valid
if (localIP.empty())
{
TLLM_THROW("getLocalIp: Can't get local ip");
}
return localIP;
}
} // namespace common
TRTLLM_NAMESPACE_END
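getLocalIp() tries the named NIC first, then a UDP connect against a public DNS address to discover a remotely reachable IP, and finally the hostname, throwing if all three fail. A hedged usage sketch, assuming the TRTLLM namespace expands to tensorrt_llm and that the header path below is visible to the caller:

#include <string>

#include "ipUtils.h" // assumed include path for the declarations above

std::string advertiseLocalEndpoint(int rank)
{
    // Empty interface name -> skip the NIC lookup and fall through to the probes.
    std::string const nic = ""; // e.g. the value of TRTLLM_MOONCAKE_INTERFACE
    return tensorrt_llm::common::getLocalIp(nic, rank); // throws if no usable IP is found
}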

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/config.h"
#include <string>
TRTLLM_NAMESPACE_BEGIN
namespace common
{
std::string getLocalIp(std::string interface, int rank);
} // namespace common
TRTLLM_NAMESPACE_END

View File

@ -37,6 +37,46 @@ NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept
return instance;
}
NcclCommResourceManager::~NcclCommResourceManager()
{
// Mark that we're in destruction to prevent cleanup attempts from deleters
// that may run during static destruction
mIsDestroying.store(true, std::memory_order_release);
// Proactively clean up all resources before destruction
// This ensures cleanup happens in a controlled manner before static destruction
std::vector<std::pair<ncclComm_t, std::vector<ResourceEntry>>> allResources;
{
std::lock_guard<std::mutex> lock(mMutex);
// Move all resources out of the map
allResources.reserve(mCommResources.size());
for (auto& [comm, resources] : mCommResources)
{
allResources.emplace_back(comm, std::move(resources));
}
mCommResources.clear();
}
// Clean up all resources outside the lock
// Note: We don't call ncclCommDestroy here - that's the responsibility
// of the shared_ptr deleter. We just clean up registered resources.
for (auto& [comm, resources] : allResources)
{
for (auto& [cleanup, name] : resources)
{
try
{
cleanup();
}
catch (...)
{
// Ignore exceptions during destruction
}
}
}
}
void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName)
{
if (!comm)
@ -60,23 +100,56 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
return;
}
// Check if we're in the process of being destroyed
// If so, skip cleanup - the destructor will handle it proactively
if (mIsDestroying.load(std::memory_order_acquire))
{
return;
}
std::vector<ResourceEntry> resourcesToClean;
{
std::lock_guard<std::mutex> lock(mMutex);
auto it = mCommResources.find(comm);
if (it == mCommResources.end())
// During static destruction, mutex and logging may not be safe.
// Use try-catch to handle any issues gracefully.
try
{
// Nothing registered for this comm, nothing to clean up
std::lock_guard<std::mutex> lock(mMutex);
// Double-check after acquiring lock (destruction may have started)
if (mIsDestroying.load(std::memory_order_acquire))
{
return;
}
auto it = mCommResources.find(comm);
if (it == mCommResources.end())
{
// Nothing registered for this comm, nothing to clean up
return;
}
// Move resources out (preserves order) and remove from map
resourcesToClean = std::move(it->second);
mCommResources.erase(it);
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_TRACE("[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(),
static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
catch (...)
{
// If mutex access fails during static destruction, just return.
// This prevents segfaults when the singleton is being destroyed.
return;
}
// Move resources out (preserves order) and remove from map
resourcesToClean = std::move(it->second);
mCommResources.erase(it);
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast<void*>(comm));
}
// Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager
@ -85,19 +158,41 @@ void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept
{
try
{
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_TRACE(
"[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
cleanup();
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(),
static_cast<void*>(comm), e.what());
try
{
TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s",
name.c_str(), static_cast<void*>(comm), e.what());
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
catch (...)
{
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
name.c_str(), static_cast<void*>(comm));
try
{
TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p",
name.c_str(), static_cast<void*>(comm));
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
}
}
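The manager now sets an atomic flag in its destructor so that cleanup calls racing with static destruction become no-ops while the destructor drains everything itself. A generic, self-contained sketch of that destruction-guard pattern (toy class, not the actual resource manager):

#include <atomic>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <vector>

class ResourceRegistrySketch
{
public:
    ~ResourceRegistrySketch()
    {
        // Tell late callers to back off, then drain everything ourselves.
        mIsDestroying.store(true, std::memory_order_release);
        std::vector<std::function<void()>> all;
        {
            std::lock_guard<std::mutex> lock(mMutex);
            for (auto& [key, fns] : mResources)
            {
                for (auto& fn : fns)
                {
                    all.push_back(std::move(fn));
                }
            }
            mResources.clear();
        }
        for (auto& fn : all)
        {
            try
            {
                fn();
            }
            catch (...)
            {
                // Ignore exceptions during teardown.
            }
        }
    }

    void add(void* key, std::function<void()> cleanup)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mResources[key].push_back(std::move(cleanup));
    }

    void cleanup(void* key) noexcept
    {
        // Destruction guard: once the destructor has started, it owns cleanup.
        if (mIsDestroying.load(std::memory_order_acquire))
        {
            return;
        }
        std::vector<std::function<void()>> fns;
        {
            std::lock_guard<std::mutex> lock(mMutex);
            auto it = mResources.find(key);
            if (it == mResources.end())
            {
                return;
            }
            fns = std::move(it->second);
            mResources.erase(it);
        }
        for (auto& fn : fns)
        {
            try
            {
                fn();
            }
            catch (...)
            {
            }
        }
    }

private:
    mutable std::mutex mMutex;
    std::unordered_map<void*, std::vector<std::function<void()>>> mResources;
    std::atomic<bool> mIsDestroying{false};
};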

View File

@ -26,6 +26,7 @@
#endif
#include <algorithm>
#include <atomic>
#include <functional>
#include <limits>
#include <memory>
@ -139,12 +140,13 @@ public:
private:
NcclCommResourceManager() = default;
~NcclCommResourceManager() = default;
~NcclCommResourceManager();
using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;
mutable std::mutex mMutex;
std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
std::atomic<bool> mIsDestroying{false};
};
// RAII helper to register a resource with a NCCL communicator.

View File

@ -123,13 +123,24 @@ std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
if (*comm)
{
// Clean up all registered resources FIRST
// The cleanupResources function uses a destruction guard to safely handle
// static destruction order issues - it will return early if the singleton
// is being destroyed (in which case the destructor handles cleanup proactively)
tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm);
// Now destroy the NCCL communicator
ncclResult_t result = ncclCommDestroy(*comm);
if (result != ncclSuccess)
{
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
// Logging may fail during static destruction, so wrap in try-catch
try
{
TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
}
catch (...)
{
// Ignore logging failures during static destruction
}
}
// Clear the communicator value before freeing the pointer

View File

@ -46,7 +46,7 @@ CUTLASS_DEVICE
void launch_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@ -57,7 +57,7 @@ CUTLASS_DEVICE
void wait_on_dependent_grids()
{
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
}

View File

@ -686,4 +686,212 @@ public:
}
};
template <class Collective>
struct MixedInputUtilsSM100
{
private:
using KernelSchedule = typename Collective::KernelSchedule;
using ConversionMode = typename Collective::ConversionMode;
using SmemLayoutA = typename Collective::SmemLayoutA;
using SmemLayoutB = typename Collective::SmemLayoutB;
using ElementScale = typename Collective::ElementScale;
using ElementZero = typename Collective::ElementZero;
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
public:
// Helper functions to select packing for conversion
template <class SrcType, class DstType, int Cosize>
struct select_packing
{ // Naive packing policy
static constexpr auto value()
{
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
}
};
/// (Designed for separate transform pipeline in Blackwell)
/// Utilities to dequantize A.
template <class EngineIn, class EngineOut, class LayoutIn, class LayoutOut, class... Ts>
CUTLASS_DEVICE static void dequantize_A_kblock_for_transform(Tensor<EngineIn, LayoutIn> const& tArA,
Tensor<EngineOut, LayoutOut>& tArACompute, cute::tuple<Ts...> const& partitioned_extra_info, int const k_block)
{
static_assert(is_rmem<EngineIn>::value, "Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value, "Output tensor for A conversion must come from registers");
static_assert(cosize_v<LayoutIn> == cosize_v<LayoutOut>);
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
auto src = tArA(_, _, _, k_block);
auto dst = tArACompute(_, _, _, k_block);
auto pSrc = raw_pointer_cast(src.data());
auto pDst = const_cast<DstType*>(raw_pointer_cast(dst.data()));
constexpr int num_elements = decltype(size(src))::value;
constexpr int pack = decltype(select_packing<SrcType, DstType, num_elements>::value())::value;
using Converter
= cutlass::NumericArrayConverter<DstType, SrcType, pack, cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
constexpr int DstElementsPerReg = 32 / sizeof_bits_v<DstType>;
using RegArray = cutlass::AlignedArray<uint32_t, pack / DstElementsPerReg, sizeof(DstArray)>;
auto src_arr = recast<SrcArray>(src);
auto dst_arr = recast<DstArray>(dst);
Tensor dst_vm = cute::group_modes<1, -1>(cute::zipped_divide(dst, pack));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert)
{
cute::transform(src_arr, dst_arr, Converter::convert);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale)
{
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
if constexpr (is_same_v<DstType, ElementScale>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero)
{
static_assert(is_same_v<ElementScale, ElementZero>, "ElementScale and ElementZero must be the same.");
auto const& scales = cute::get<1>(partitioned_extra_info)(_, _, _, k_block);
auto const& zeros = cute::get<3>(partitioned_extra_info)(_, _, _, k_block);
CUTE_STATIC_ASSERT_V(size(src) == size(scales));
CUTE_STATIC_ASSERT_V(size(src) == size(zeros));
if constexpr (is_same_v<DstType, ElementZero>)
{
cute::transform(src_arr, dst_arr, Converter::convert);
using ScaleArray = cutlass::Array<ElementScale, pack>;
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
using ZeroArray = cutlass::Array<ElementZero, pack>;
auto zero_arr = recast<ZeroArray>(filter_zeros(zeros));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>)
{
Tensor scales_vm = cute::group_modes<1, -1>(cute::zipped_divide(scales, pack));
Tensor zeros_vm = cute::group_modes<1, -1>(cute::zipped_divide(zeros, pack));
for (int i = 0; i < size<1>(dst_vm); ++i)
{
auto&& r = cute::recast<RegArray>(dst_vm(_, i))(0);
auto&& scale_reg = cute::recast<RegArray>(scales_vm(_, i))(0);
auto&& zero_reg = cute::recast<RegArray>(zeros_vm(_, i))(0);
CUTLASS_PRAGMA_UNROLL
for (size_t ii = 0; ii < RegArray::kElements; ++ii)
{
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hmul2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(scale_reg[ii]));
bf16x2_val = __hadd2(bf16x2_val, reinterpret_cast<__nv_bfloat162 const&>(zero_reg[ii]));
}
}
}
else
{
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{});
cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{});
}
}
else
{
constexpr int pack1 = decltype(select_packing<SrcType, ElementScale, num_elements>::value())::value;
constexpr int pack2 = decltype(select_packing<ElementScale, DstType, num_elements>::value())::value;
constexpr int pack = cute::gcd(pack1, pack2);
using Converter1 = cutlass::NumericArrayConverter<ElementScale, SrcType, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using Converter2 = cutlass::NumericArrayConverter<DstType, ElementScale, pack,
cutlass::FloatRoundStyle::round_to_nearest>;
using SrcArray = cutlass::Array<SrcType, pack>;
using DstArray = cutlass::Array<DstType, pack>;
using StageArray = cutlass::Array<ElementScale, pack>;
constexpr int iters = num_elements / pack;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < iters; ++i)
{
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc) + i;
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst) + i;
StageArray stageArr;
stageArr = Converter1::convert(*pSrcArr);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < pack; ++j)
{
stageArr[j] = stageArr[j] * scales[i * pack + j] + zeros[i * pack + j];
}
*pDstArr = Converter2::convert(stageArr);
}
}
}
else
{
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled for input partitioning.");
}
}
};
} // namespace cutlass::gemm::collective::detail
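select_packing picks the widest conversion vector that divides both the fragment size and one 32-bit register, using the narrower of the source and destination element types. A standalone worked example of that gcd arithmetic (plain constexpr code rather than the cute types used above):

#include <algorithm>
#include <numeric>

// Same policy as select_packing: pack = gcd(cosize, 32 / min(src_bits, dst_bits)).
constexpr int selectPack(int cosize, int srcBits, int dstBits)
{
    return std::gcd(cosize, 32 / std::min(srcBits, dstBits));
}

static_assert(selectPack(/*cosize=*/32, /*uint4_t src=*/4, /*bf16 dst=*/16) == 8);
static_assert(selectPack(/*cosize=*/32, /*uint8_t src=*/8, /*bf16 dst=*/16) == 4);
static_assert(selectPack(/*cosize=*/12, /*uint8_t src=*/8, /*bf16 dst=*/16) == 4);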

View File

@ -0,0 +1,294 @@
/*
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int stages>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(StageCount<stages> stage_count)
{
constexpr int Load2TransformStageCount = stages;
constexpr int Transform2MmaStageCount = stages;
constexpr int AccumulatorStageCount = stages;
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
template <int CapacityBytes, class ElementA, class ElementAMma, class ElementScale, class ElementZero, class ElementB,
class CtaTileShape_MNK, class TiledMma, class KernelScheduleType, UMMA::Major UmmaMajorA, int ScaleGranularityK,
int carveout_bytes>
constexpr cute::tuple<int, int, int> sm100_compute_stage_count_or_override_weightonly(
StageCountAutoCarveout<carveout_bytes> stage_count)
{
constexpr int CtaM = get<0>(CtaTileShape_MNK{});
constexpr int CtaN = get<1>(CtaTileShape_MNK{});
static_assert(CtaN <= 128, "Can't support CtaN>128 tiles");
constexpr int CtaK = get<2>(CtaTileShape_MNK{});
using AtomThrID = typename TiledMma::AtomThrID;
constexpr int TmemColumns = 512;
constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>;
constexpr bool IsAComputeinSmem = !IsAComputeinTmem;
// Detect 2x2 TMEM layout
constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN / 2 : CtaN;
constexpr int TmemAWordsPerDP = CtaK / 2;
constexpr int AccumulatorStageCount
= (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP);
constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32);
constexpr int TmemInAStageCount_Potential
= (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000;
// Mainload2Transform Pipeline
constexpr auto load2transform_pipeline_bytes
= sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>; // ElementA introduce here
constexpr auto s_bits = cute::is_void_v<ElementScale> ? 0 : cute::sizeof_bits_v<ElementScale>;
constexpr auto z_bits = cute::is_void_v<ElementZero> ? 0 : cute::sizeof_bits_v<ElementZero>;
constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage);
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>; // ElementB introduce here
constexpr int ab_stage_bytes
= cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK)
+ cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{}))
+ static_cast<int>(load2transform_pipeline_bytes) + static_cast<int>(load2mma_pipeline_bytes);
// Transform2Mma Pipeline
constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
constexpr auto a_compute_bits = cute::sizeof_bits_v<ElementAMma>;
constexpr int ab_compute_stage_bytes = cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem)
* size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}))
+ // If ACompute is in TMEM, Acompute buffer has 0 bytes.
static_cast<int>(transform2mma_pipeline_bytes);
constexpr int ABComputeStageCount_Potential
= SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes);
// The number of SMEM buffers for A, B, ACompute (if in SMEM), and BCompute should be at least Transform2MmaStageCount
constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential);
constexpr int SmemCapacityAfterABComputeCarveout
= SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes);
// Can we boost the number of buffers for A and B?
constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes;
static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2,
"Not enough SMEM or TMEM capacity for selected tile size");
return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount);
}
} // namespace detail
// Mixed Input MMA kernels builder
template <class ElementAOptionalTuple, class GmemLayoutATagTuple, int AlignmentA, class ElementBOptionalTuple,
class GmemLayoutBTag, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, // The Cluster-level TileShape
class ClusterShape_MNK, class StageCountType, class KernelScheduleType>
struct CollectiveBuilderSm100WeightOnly<arch::Sm100, arch::OpClassTensorOp,
ElementAOptionalTuple, // ElementA
GmemLayoutATagTuple, // LayoutA
AlignmentA,
ElementBOptionalTuple, // ElementB
GmemLayoutBTag, // LayoutB
AlignmentB, ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int)
StageCountType, KernelScheduleType,
cute::enable_if_t<(cute::is_base_of_v<KernelScheduleSm100MixedInputGemm, KernelScheduleType>) &&(
(sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0)
&& ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>>
{
using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>;
using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>;
static constexpr cute::UMMA::Major UmmaMajorA
= cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB
= cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>;
using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>;
using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>;
static constexpr bool NeitherIsTuple
= !cute::is_tuple<ElementAOptionalTuple>::value && !cute::is_tuple<ElementBOptionalTuple>::value;
static constexpr bool IsANarrow = cute::sizeof_bits_v<ElementA> < cute::sizeof_bits_v<ElementB>;
static constexpr bool IsMixedInput = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm.");
static_assert(
(cute::is_tuple<ElementAOptionalTuple>::value ^ cute::is_tuple<ElementBOptionalTuple>::value
|| (NeitherIsTuple && (cute::sizeof_bits<ElementA>::value != cute::sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
using ElementPairA = cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>,
ElementAOptionalTuple>;
using ElementPairB = cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>,
ElementBOptionalTuple>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
static_assert(IsATransformed, "A matrix should be transformed.");
// For fp32 types, map to tf32 MMA value type.
using ElementMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
using ElementAMma = ElementMma;
using ElementBMma = ElementMma;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<ElementA> < 8;
using TmaElementA = cute::conditional_t<IsSubbyteA, uint8_t, ElementA>;
static constexpr int ScalingFactor = 1;
using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma<ElementAMma, ElementB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, KernelScheduleType>());
using AtomThrID = typename TiledMma::AtomThrID;
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{}));
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(
TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(
TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), cute::size<2>(TileShape_MNK{}))));
using BlockTileA_M = decltype(cute::size<0, 0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{}));
using BlockTileA_K = decltype(cute::size<0, 1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{}));
// Input transform kernel cannot use TMA 2SM instructions.
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA, ElementA,
BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorA,
ElementAMma, BlockTileA_M, BlockTileA_K>());
using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomA,
SmemLayoutAtomACompute>;
static constexpr int MMA_M = cute::size<0, 0>(MmaShapeA_MK{});
using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>,
cute::conditional_t<
(UmmaMajorA == cute::UMMA::Major::K
&& !cute::is_base_of_v<KernelTmaWarpSpecializedMixedInputSmemSm100, KernelScheduleType>),
cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x,
SM100_TMEM_STORE_32dp32b8x>, // TS Implementation
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementA>> // SS Implementation
>;
using BlockTileB_N = decltype(cute::size<0, 0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0, 1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
// Input transform kernel cannot use TMA 2SM instructions.
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB, ElementB,
BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<UmmaMajorB,
ElementBMma, BlockTileB_N, BlockTileB_K>());
using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType<SmemLayoutAtomB,
SmemLayoutAtomBCompute>;
using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType<
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementB>,
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementMma>>;
// Create the stride of the transformed input.
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using LayoutScale = cutlass::gemm::TagToStrideA_t<GmemLayoutScaleTag>;
using VoidShapeScale
= Shape<Shape<Int<128>, _1>, Shape<Int<64>, _1>, _1>; // Dummy Value to create a dummy ScaleConfig
using VoidStrideScale = Stride<Stride<_0, _1>, Stride<_0, _1>, _1>;
using VoidLayoutScale = Layout<VoidShapeScale, VoidStrideScale>;
using NonVoidLayoutScale = cute::conditional_t<cute::is_void_v<LayoutScale>, VoidLayoutScale, LayoutScale>;
using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{}));
// SmemCarveout
static constexpr int SchedulerPipelineStageCount = 3;
static constexpr bool IsArrayOfPointersGemm
= (cute::is_base_of_v<KernelScheduleSm100PtrArrayFastFP32Gemm, KernelScheduleType>);
// CLCPipeline = PipelineCLCFetchAsync
static constexpr auto CLCPipelineStorage
= sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// CLC Throttle pipeline storage
static constexpr auto CLCThrottlePipelineStorage
= sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
// Tmem dealloc
static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier);
// Tmem ptr storage
static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t);
// Tensormap Storage
static constexpr size_t TensorMapStorage
= IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
static constexpr auto KernelSmemCarveout = static_cast<int>(CLCPipelineStorage + CLCResponseStorage
+ CLCThrottlePipelineStorage + TmemDeallocStorage + TmemBasePtrsStorage + TensorMapStorage);
// Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations
static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int ScaleGranularityK = get_ScaleGranularityK<LayoutScale>();
static constexpr auto stage_info
= cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_weightonly<
Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB,
CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{});
static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info);
static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info);
static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info);
static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell");
using DispatchPolicy
= cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput<Load2TransformPipelineStageCount,
Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount,
ClusterShape_MNK>;
using CollectiveOp = cutlass::gemm::collective::CollectiveMmaSm100WeightOnly<DispatchPolicy, TileShape_MNK,
ElementPairA, StridePairA, ElementPairB, cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>, TiledMma,
GmemTiledCopyA, SmemLayoutAtomPairA, CopyAtomPairA, cute::identity, GmemTiledCopyB, SmemLayoutAtomPairB,
CopyAtomPairB, cute::identity>;
};
} // namespace cutlass::gemm::collective
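The stage-count helper budgets the 512 TMEM columns between accumulator stages and A-compute stages before dividing the remaining SMEM among the A/B buffers. A worked arithmetic sketch for one assumed CTA tile (CtaN = 128, CtaK = 64, A kept in TMEM); the numbers only illustrate the formulas above and are not taken from a real instantiation:

// Illustrative stage-count arithmetic, CtaN = 128, CtaK = 64, A staged in TMEM.
constexpr int TmemColumns = 512;
constexpr int TmemAccWordsPerDP = 128;  // CtaN, assuming no 2x2 TMEM split
constexpr int TmemAWordsPerDP = 64 / 2; // CtaK / 2

// Matches (TmemAccWordsPerDP == 128) ? 2 : 3 above.
constexpr int AccumulatorStageCount = 2;

// Columns left over for staging A in TMEM.
constexpr int TmemInAStageCountPotential
    = (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP;
static_assert(TmemInAStageCountPotential == 8); // (512 - 256) / 32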

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_sm100_weightonly.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, class Enable = void>
struct CollectiveBuilderSm100WeightOnly
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm100_umma_builder_weightonly.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,42 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB>
struct CollectiveMmaSm100WeightOnly
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -533,8 +533,8 @@ struct GemmFpAIntB
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 1000)
// Use SM80 implementation for GB10x, GB20x.
#elif (__CUDA_ARCH__ >= 1200)
// Use SM80 implementation for GB20x.
run_kernel<arch::Sm80>(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.

View File

@ -87,7 +87,9 @@ public:
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -102,7 +104,9 @@ public:
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability != 100
&& Arch::kMinComputeCapability != 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
@ -116,6 +120,26 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
typename platform::enable_if<Arch::kMinComputeCapability == 100 || Arch::kMinComputeCapability == 103>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
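The new specializations steer SM100/SM103 to a plain column-major B with OpMultiplyAdd instead of the interleaved dequantize path used from SM75 through SM90. A self-contained sketch of the same enable_if-by-compute-capability dispatch, using toy traits rather than the cutlass ones:

#include <type_traits>

struct OpDequantizeInterleavedToy {};
struct OpMultiplyAddToy {};

template <int Cc, typename Enable = void>
struct ToyLayoutDetailsB;

// Turing through Hopper: interleaved dequantize-after-load path.
template <int Cc>
struct ToyLayoutDetailsB<Cc, std::enable_if_t<(Cc >= 75) && Cc != 100 && Cc != 103>>
{
    using Operator = OpDequantizeInterleavedToy;
};

// Blackwell SM100/SM103: plain column-major B with a regular multiply-add.
template <int Cc>
struct ToyLayoutDetailsB<Cc, std::enable_if_t<Cc == 100 || Cc == 103>>
{
    using Operator = OpMultiplyAddToy;
};

static_assert(std::is_same_v<ToyLayoutDetailsB<90>::Operator, OpDequantizeInterleavedToy>);
static_assert(std::is_same_v<ToyLayoutDetailsB<100>::Operator, OpMultiplyAddToy>);
static_assert(std::is_same_v<ToyLayoutDetailsB<103>::Operator, OpMultiplyAddToy>);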

View File

@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES})
if(FILE_EXT STREQUAL ".py")
# Read file content and replace module imports for Python files
file(READ ${SOURCE_FILE} _content)
string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content
string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm"
_content "${_content}")
string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
"${_content}")
string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content
"${_content}")
string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." _content
"${_content}")
# Add adaptation header

View File

@ -90,4 +90,5 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}
PUBLIC TOP_LEVEL_DIR="${TOP_LEVEL_DIR}")
add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/mooncake_utils)
add_subdirectory(cache_transmission/nixl_utils)

View File

@ -141,7 +141,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
NotificationInfo notificationInfo{syncInfo};
std::stringstream ss;
NotificationInfo::serialize(notificationInfo, ss);
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
// TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
}
@ -150,7 +151,7 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) cons
{
NotificationSyncInfo syncInfo{mAgentName, ctx};
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo);
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
}
void AgentConnection::sendRequestAndBufferInfo(batch_manager::RequestInfo& requestInfo,
@ -230,13 +231,13 @@ void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) cons
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
{
ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo);
return true;
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate());
return readySignalInfo.mIsReady;
}
AgentConnectionManager::AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState)
CacheState cacheState, std::string const& backendType)
: mCacheState(std::move(cacheState))
, mCacheTransBufferManagers(std::move(cacheTransBufferManagers))
, mRegMemDescs(MemoryType::kVRAM, {})
@ -246,8 +247,8 @@ AgentConnectionManager::AgentConnectionManager(
mAgentName = genUniqueAgentName();
// Create Agent
BaseAgentConfig config{mAgentName, true};
m_Agent = makeTransferAgent("nixl", &config);
BaseAgentConfig config{mAgentName, true, false, true, 1};
m_Agent = makeTransferAgent(backendType, &config);
TLLM_CHECK(!mCacheTransBufferManagers.empty());
std::vector<MemoryDesc> memDescs;
for (auto* cacheTransBufferManager : mCacheTransBufferManagers)
@ -315,9 +316,10 @@ AgentConnectionManager::AgentConnectionManager(
" ***** AgentConnectionManager::AgentConnectionManager mCommState: %s", mCommState.toString().c_str());
}
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo)
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag)
{
while (true)
while (!terminateFlag.load())
{
if (!mIsRunning)
{
@ -490,16 +492,16 @@ int AgentConnectionManager::getDeviceId() const
}
template <typename NotificationType>
void AgentConnectionManager::waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo)
void AgentConnectionManager::waitForNotification(
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag)
{
while (true)
while (!terminateFlag.load())
{
if (!mIsRunning)
{
return;
}
updateUnhandledNotifications();
std::scoped_lock lock(mNotificationMutex);
auto it = mUnhandledNotifications.begin();
@ -575,18 +577,20 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
// Explicit template instantiations
template void AgentConnectionManager::waitForNotification<NotificationSyncInfo>(
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo);
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
template void AgentConnectionManager::waitForNotification<ReadySignalInfo>(
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo);
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo)
void AgentConnectionManager::waitForSyncInfo(
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag)
{
waitForNotification(remoteAgentName, syncInfo);
waitForNotification(remoteAgentName, syncInfo, terminateFlag);
}
void AgentConnectionManager::waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo)
void AgentConnectionManager::waitForReadySignal(
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag)
{
waitForNotification(remoteAgentName, readySignalInfo);
waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
}
std::string const& AgentConnectionManager::getAgentName() const

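All of the blocking waits in the agent connection path now poll a terminate flag threaded through DataContext, so shutdown no longer hangs on a notification that will never arrive. A generic sketch of that cancellable-wait pattern (toy helper, not the NIXL agent API):

#include <atomic>
#include <chrono>
#include <optional>
#include <thread>

// Spin on poll() until it yields a value or the terminate flag is raised.
template <typename T, typename Poll>
std::optional<T> waitUnlessTerminated(Poll&& poll, std::atomic<bool> const& terminate)
{
    while (!terminate.load())
    {
        if (std::optional<T> result = poll())
        {
            return result;
        }
        std::this_thread::sleep_for(std::chrono::microseconds(100));
    }
    return std::nullopt; // shutting down; callers treat this as "not ready"
}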
View File

@ -277,12 +277,13 @@ class AgentConnectionManager : public ConnectionManager
public:
AgentConnectionManager(
std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> cacheTransBufferManagers,
CacheState cacheState);
CacheState cacheState, std::string const& backendType);
~AgentConnectionManager();
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
[[nodiscard]] CommState const& getCommState() const override;
AgentConnection const* recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo);
AgentConnection const* recvConnectionAndRequestInfo(
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag);
[[nodiscard]] std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> const&
getCacheTransBufferManagers() const;
void updateUnhandledNotifications();
@ -293,9 +294,12 @@ public:
[[nodiscard]] std::string const& getAgentName() const;
template <typename NotificationType>
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
void waitForNotification(
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag);
void waitForSyncInfo(
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag);
void waitForReadySignal(
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag);
[[nodiscard]] bool isRunning() const override;
private:

View File

@ -107,9 +107,9 @@ TargetRanksInfo TargetRanksInfoForDP(
auto const peerCPNum = peerParConfig.mContextParallelism;
auto const selfCPNum = selfParConfig.mContextParallelism;
auto const selfTPRank = selfRank % selfTPNum;
auto const selfCPRank = selfRank % selfCPNum;
auto const selfTPRank = (selfRank % (selfTPNum * selfCPNum)) / selfCPNum;
auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;
int peerPPRankStart = 0;
int mDomainPPSize = 1;
@ -205,13 +205,14 @@ TargetRanksInfo TargetRanksInfoForDP(
}
std::vector<int> retRanks;
for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
for (int i = peerCPRankStart; i < peerCPRankEnd; i++)
{
for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
for (int j = peerTPRankStart; j < peerTPRankEnd; j++)
{
for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
{
int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
// Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
int irank = (k * peerTPNum * peerCPNum) + (j * peerCPNum) + i;
retRanks.push_back(irank);
}
}
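A minimal standalone sketch of the rank layout assumed by the loop above; the parallelism sizes are made up for illustration, and the decomposition simply inverts the formula stated in the comment.

#include <cassert>

int main()
{
    int const tpNum = 2, cpNum = 2, ppNum = 2; // illustrative sizes only
    for (int pp = 0; pp < ppNum; ++pp)
        for (int tp = 0; tp < tpNum; ++tp)
            for (int cp = 0; cp < cpNum; ++cp)
            {
                // Same layout as above: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
                int const rank = pp * (tpNum * cpNum) + tp * cpNum + cp;
                // Round-trip decomposition of that rank.
                assert(rank / (tpNum * cpNum) == pp);
                assert((rank % (tpNum * cpNum)) / cpNum == tp);
                assert(rank % cpNum == cp);
            }
    return 0;
}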

View File

@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this material and related documentation without an express
# license agreement from NVIDIA CORPORATION or its affiliates is strictly
# prohibited.
# MOONCAKE is not supported on Rocky8 for now. The check below is conservative:
# any RHEL-family distribution exposing /etc/redhat-release is treated as Rocky8.
set(IS_ROCKY8 FALSE)
if(EXISTS "/etc/redhat-release")
set(IS_ROCKY8 TRUE)
endif()
if(MOONCAKE_ROOT AND NOT IS_ROCKY8)
find_library(TRANSFER_ENGINE_LIB transfer_engine ${MOONCAKE_ROOT}/lib)
find_path(TRANSFER_ENGINE_INCLUDE_DIR transfer_engine_c.h
${MOONCAKE_ROOT}/include)
message(STATUS "Find transfer engine results:")
message(STATUS " TRANSFER_ENGINE_LIB = ${TRANSFER_ENGINE_LIB}")
message(
STATUS " TRANSFER_ENGINE_INCLUDE_DIR = ${TRANSFER_ENGINE_INCLUDE_DIR}")
if(TRANSFER_ENGINE_LIB AND TRANSFER_ENGINE_INCLUDE_DIR)
set(MOONCAKE_WRAPPER_TARGET "tensorrt_llm_mooncake_wrapper")
add_library(${MOONCAKE_WRAPPER_TARGET} SHARED transferAgent.cpp)
target_compile_options(${MOONCAKE_WRAPPER_TARGET} PRIVATE -Wno-error)
target_include_directories(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET}
PRIVATE ${TRANSFER_ENGINE_LIB} CUDA::cudart)
# Export variables to parent scope for transfer_agent_binding
set(TRANSFER_ENGINE_INCLUDE_DIR
${TRANSFER_ENGINE_INCLUDE_DIR}
PARENT_SCOPE)
endif()
endif()
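The wrapper above is built as a plain shared library whose extern "C" factory is resolved at runtime through the generic transfer-agent loader. A hedged sketch of that resolution step follows; the library file name is an assumption for illustration, and production code goes through makeTransferAgent instead of calling dlopen directly.

#include <dlfcn.h>

#include <cstdio>

int main()
{
    // Hypothetical library name; the real file name and search path come from the build/install layout.
    void* handle = dlopen("libtensorrt_llm_mooncake_wrapper.so", RTLD_NOW | RTLD_LOCAL);
    if (handle == nullptr)
    {
        std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
        return 1;
    }
    // The wrapper exports createMooncakeTransferAgent with C linkage (see transferAgent.cpp below).
    void* factory = dlsym(handle, "createMooncakeTransferAgent");
    std::printf("factory symbol %s\n", factory != nullptr ? "found" : "missing");
    dlclose(handle);
    return 0;
}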

View File

@ -0,0 +1,612 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/ipUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/transferAgent.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <algorithm>
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
namespace tensorrt_llm::executor::kv_cache
{
MooncakeTransferStatus::MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount)
: mEngine{engine}
, mBatchId{batchId}
, mRequestCount{requestCount}
{
TLLM_CHECK(mEngine);
}
TransferState MooncakeTransferStatus::wait(int64_t timeout_ms) const
{
auto startTime = std::chrono::steady_clock::now();
while (true)
{
if (mBatchFreed)
{
return TransferState::kSUCCESS;
}
bool has_failed = false;
bool all_completed = true;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR(
"Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status != STATUS_COMPLETED)
{
all_completed = false;
}
}
// If any request failed, return failure
if (has_failed)
{
return TransferState::kFAILURE;
}
// If all requests completed successfully
if (all_completed)
{
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed in wait()", mBatchId);
syncSegmentCache(mEngine);
return TransferState::kSUCCESS;
}
// If timeout_ms < 0, wait indefinitely
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
[[nodiscard]] bool MooncakeTransferStatus::isCompleted() const
{
if (mBatchFreed)
{
return true;
}
bool has_failed = false;
for (size_t index = 0; index < mRequestCount; ++index)
{
transfer_status_t status;
int rc = getTransferStatus(mEngine, mBatchId, index, &status);
if (rc || status.status == STATUS_FAILED)
{
has_failed = true;
if (rc)
{
TLLM_LOG_ERROR(
"Failed to get transfer status for batch %lu, task %zu: error code %d", mBatchId, index, rc);
}
else
{
TLLM_LOG_ERROR("Transfer failed for batch %lu, task %zu: status %d", mBatchId, index, status.status);
}
}
else if (status.status == STATUS_PENDING || status.status == STATUS_WAITING)
{
TLLM_LOG_DEBUG("Transfer is pending for batch %lu, task %zu", mBatchId, index);
return false;
}
}
if (!has_failed)
{
// Each batchId is allocated with a fixed batch size and cannot carry more requests
// than that size. Free the batch ID here to work around the issue where the same
// batchId could otherwise be reused to post multiple transfers.
freeBatchID(mEngine, mBatchId);
mBatchFreed = true;
TLLM_LOG_DEBUG("Batch ID %lu freed, future calls will return true directly", mBatchId);
}
// Currently the return value cannot distinguish a failed transfer from a completed one; both report true here.
TLLM_LOG_DEBUG("Transfer is completed for batch %lu", mBatchId);
return true;
}
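A minimal caller-side sketch of the polling contract implemented above, assuming a status object returned by submitTransferRequests; the 100 ms budget is arbitrary, and kIN_PROGRESS only means the budget elapsed.

#include "tensorrt_llm/executor/transferAgent.h"

#include <memory>

namespace kvc = tensorrt_llm::executor::kv_cache;

bool driveTransferToCompletion(std::unique_ptr<kvc::TransferStatus> const& status)
{
    auto state = status->wait(/*timeout_ms=*/100);
    while (state == kvc::TransferState::kIN_PROGRESS)
    {
        // The transfer is still running; poll again (real code would do useful work here).
        state = status->wait(/*timeout_ms=*/100);
    }
    return state == kvc::TransferState::kSUCCESS;
}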
std::string const MooncakeBase64Helper::STANDARD_CHARS
= "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
std::string MooncakeBase64Helper::encode(std::vector<uint8_t> const& data)
{
return encodeInternal(data, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::encode(std::string const& data)
{
std::vector<uint8_t> vec(data.begin(), data.end());
return encode(vec);
}
std::vector<uint8_t> MooncakeBase64Helper::decode(std::string const& encoded)
{
return decodeInternal(encoded, STANDARD_CHARS);
}
std::string MooncakeBase64Helper::decodeToString(std::string const& encoded)
{
auto vec = decode(encoded);
return std::string(vec.begin(), vec.end());
}
std::string MooncakeBase64Helper::encodeInternal(std::vector<uint8_t> const& data, std::string const& chars)
{
std::string encoded;
size_t i = 0;
size_t j = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
size_t dataLen = data.size();
uint8_t const* bytes = data.data();
while (dataLen--)
{
charArray3[i++] = *(bytes++);
if (i == 3)
{
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (i = 0; i < 4; i++)
{
encoded += chars[charArray4[i]];
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 3; j++)
{
charArray3[j] = '\0';
}
charArray4[0] = (charArray3[0] & 0xfc) >> 2;
charArray4[1] = ((charArray3[0] & 0x03) << 4) + ((charArray3[1] & 0xf0) >> 4);
charArray4[2] = ((charArray3[1] & 0x0f) << 2) + ((charArray3[2] & 0xc0) >> 6);
charArray4[3] = charArray3[2] & 0x3f;
for (j = 0; j < i + 1; j++)
{
encoded += chars[charArray4[j]];
}
while (i++ < 3)
{
encoded += '=';
}
}
return encoded;
}
std::vector<uint8_t> MooncakeBase64Helper::decodeInternal(std::string const& encoded, std::string const& chars)
{
size_t encodedLen = encoded.size();
size_t i = 0;
size_t j = 0;
size_t in_ = 0;
std::array<uint8_t, 3> charArray3{};
std::array<uint8_t, 4> charArray4{};
std::vector<uint8_t> decoded;
std::string cleanEncoded;
for (char c : encoded)
{
if (!isWhitespace(c))
{
cleanEncoded += c;
}
}
encodedLen = cleanEncoded.size();
while (encodedLen-- && cleanEncoded[in_] != '=' && isBase64(cleanEncoded[in_], chars))
{
charArray4[i++] = cleanEncoded[in_];
in_++;
if (i == 4)
{
for (i = 0; i < 4; i++)
{
charArray4[i] = chars.find(charArray4[i]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (i = 0; i < 3; i++)
{
decoded.push_back(charArray3[i]);
}
i = 0;
}
}
if (i > 0)
{
for (j = i; j < 4; j++)
{
charArray4[j] = 0;
}
for (j = 0; j < 4; j++)
{
charArray4[j] = chars.find(charArray4[j]);
}
charArray3[0] = (charArray4[0] << 2) + ((charArray4[1] & 0x30) >> 4);
charArray3[1] = ((charArray4[1] & 0xf) << 4) + ((charArray4[2] & 0x3c) >> 2);
charArray3[2] = ((charArray4[2] & 0x3) << 6) + charArray4[3];
for (j = 0; j < i - 1; j++)
{
decoded.push_back(charArray3[j]);
}
}
return decoded;
}
bool MooncakeBase64Helper::isBase64(uint8_t c, std::string const& chars)
{
return (isalnum(c) || (c == chars[62]) || (c == chars[63]));
}
bool MooncakeBase64Helper::isWhitespace(uint8_t c)
{
return (c == ' ' || c == '\n' || c == '\r' || c == '\t');
}
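A short round-trip sketch of the helper above (illustrative only; it relies on nothing beyond the public entry points already shown).

#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"

#include <cassert>
#include <string>

void base64RoundTripExample()
{
    namespace kvc = tensorrt_llm::executor::kv_cache;
    std::string const original = "sync-message";
    std::string const encoded = kvc::MooncakeBase64Helper::encode(original);
    // decode() skips whitespace and padding, so the round trip gives back the original bytes.
    assert(kvc::MooncakeBase64Helper::decodeToString(encoded) == original);
}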
MooncakeTransferAgent::MooncakeTransferAgent(BaseAgentConfig const& config)
{
mLocalAgentName = config.mName;
std::string segmentName = "127.0.0.1";
if (getenv("TLLM_MOONCAKE_IP_ADDR"))
{
segmentName = std::string(getenv("TLLM_MOONCAKE_IP_ADDR"));
}
else
{
auto ip = common::getLocalIp(common::getEnvMooncakeInterface(), mpi::MpiComm::session().getRank());
if (!ip.empty())
{
segmentName = ip;
}
}
mEngine = createTransferEngine("P2PHANDSHAKE", segmentName.c_str(), "", 0, true);
}
void MooncakeTransferAgent::registerMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::registerMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
it->second->addRef();
continue;
}
int err = registerLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()), desc.getLen(), "*", 1);
TLLM_CHECK_WITH_INFO(err == 0, "registerLocalMemory failed, addr: %p, len: %lu",
reinterpret_cast<void*>(desc.getAddr()), desc.getLen());
auto mooncakeDesc = std::make_shared<MooncakeMemoryDesc>(desc);
mMemRegInfo[desc.getAddr()] = std::move(mooncakeDesc);
}
}
void MooncakeTransferAgent::deregisterMemory(RegisterDescs const& descs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::deregisterMemory");
std::lock_guard<std::mutex> lock(mMutex);
for (auto const& desc : descs.getDescs())
{
auto it = mMemRegInfo.find(desc.getAddr());
if (it != mMemRegInfo.end())
{
auto const& mooncakeDesc = it->second;
mooncakeDesc->releaseRef();
if (mooncakeDesc->getRefCount())
continue;
int err = unregisterLocalMemory(mEngine, reinterpret_cast<void*>(desc.getAddr()));
TLLM_CHECK_WITH_INFO(
err == 0, "unregisterLocalMemory failed, addr: %p", reinterpret_cast<void*>(desc.getAddr()));
mMemRegInfo.erase(desc.getAddr());
}
}
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::loadRemoteAgent");
// Do the same thing as loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
loadRemoteAgent(name, agentDesc.getBackendAgentDesc());
}
void MooncakeTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"MooncakeTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
name.c_str());
std::lock_guard<std::mutex> lock(mMutex);
auto segmentId = openSegment(mEngine, connectionInfo.c_str());
TLLM_CHECK_WITH_INFO(
segmentId >= 0, "loadRemoteAgent openSegment failed, connectionInfo: %s", connectionInfo.c_str());
mConnectedAgents[name].segmentId = segmentId;
}
void MooncakeTransferAgent::invalidateRemoteAgent(std::string const& name)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::invalidateRemoteAgent");
}
AgentDesc MooncakeTransferAgent::getLocalAgentDesc()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalAgentDesc");
// Using connection info as agent desc
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalIpAndPort failed");
return AgentDesc{std::string(connectionInfo)};
}
ConnectionInfoType MooncakeTransferAgent::getLocalConnectionInfo()
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::getLocalConnectionInfo");
static size_t const kBufLen = 64;
char connectionInfo[kBufLen];
int ret = getLocalIpAndPort(mEngine, connectionInfo, kBufLen);
TLLM_CHECK_WITH_INFO(ret == 0, "MooncakeTransferAgent::getLocalAgentDesc::getLocalConnectionInfo failed");
return std::string(connectionInfo);
}
[[nodiscard]] std::unique_ptr<TransferStatus> MooncakeTransferAgent::submitTransferRequests(
TransferRequest const& request)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::submitTransferRequests");
bool hasNotif = false;
std::string syncMessage;
if (request.getSyncMessage().has_value())
{
hasNotif = true;
syncMessage = request.getSyncMessage().value();
}
static size_t const kMaxRequestCount = 1024;
uint64_t batchId = allocateBatchID(mEngine, kMaxRequestCount);
TLLM_CHECK_WITH_INFO(batchId != INVALID_BATCH, "allocateBatchID failed");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
std::string remoteName = request.getRemoteName();
auto it = mConnectedAgents.find(remoteName);
if (it == mConnectedAgents.end())
{
std::string error = "Remote agent " + remoteName + "not found";
TLLM_THROW(error);
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
auto localDescs = request.getSrcDescs().getDescs();
auto remoteDescs = request.getDstDescs().getDescs();
TLLM_CHECK_WITH_INFO(localDescs.size() == remoteDescs.size(), "Number of local and remote memory must match");
size_t requestCount = localDescs.size();
std::vector<transfer_request_t> transferRequests(requestCount);
for (size_t index = 0; index < requestCount; ++index)
{
TLLM_CHECK_WITH_INFO(
localDescs[index].getLen() == remoteDescs[index].getLen(), "Length of local and remote memory must match");
transferRequests[index].opcode = (request.getOp() == TransferOp::kREAD) ? OPCODE_READ : OPCODE_WRITE;
transferRequests[index].source = reinterpret_cast<void*>(localDescs[index].getAddr());
transferRequests[index].target_offset = remoteDescs[index].getAddr();
transferRequests[index].length = localDescs[index].getLen();
transferRequests[index].target_id = segmentId;
}
int rc = 0;
if (hasNotif)
{
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
notifyMsg.msg = const_cast<char*>(syncMessage.c_str());
rc = submitTransferWithNotify(mEngine, batchId, transferRequests.data(), requestCount, notifyMsg);
}
else
{
rc = submitTransfer(mEngine, batchId, transferRequests.data(), requestCount);
}
TLLM_CHECK_WITH_INFO(rc == 0, "submitTransfer failed with status: %d", rc);
return std::make_unique<MooncakeTransferStatus>(mEngine, batchId, requestCount);
}
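Caller-side sketch of the submit path above, assuming the remote agent has already been loaded and that srcDescs/dstDescs describe registered memory regions of equal lengths; the remote name, the WRITE opcode, and the TransferDescs parameter type (taken from the binding signature below) are illustrative assumptions.

#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"

#include <optional>
#include <utility>

namespace kvc = tensorrt_llm::executor::kv_cache;

// 'agent' has already loaded the remote peer; both descriptor lists cover registered memory.
kvc::TransferState writeOnce(kvc::MooncakeTransferAgent& agent, kvc::TransferDescs srcDescs, kvc::TransferDescs dstDescs)
{
    kvc::TransferRequest request{kvc::TransferOp::kWRITE, std::move(srcDescs), std::move(dstDescs),
        /*remoteName=*/"generation_rank_0", /*syncMessage=*/std::nullopt};
    auto status = agent.submitTransferRequests(request);
    return status->wait(); // blocks until success or failure (timeout_ms defaults to -1)
}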
void MooncakeTransferAgent::notifySyncMessage(std::string const& name, SyncMessage const& syncMessage)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage");
int segmentId;
{
std::lock_guard<std::mutex> lock(mMutex);
auto it = mConnectedAgents.find(name);
if (it == mConnectedAgents.end())
{
TLLM_LOG_WARNING("Remote agent %s not found", name.c_str());
return;
}
auto const& agentInfo = it->second;
segmentId = agentInfo.segmentId;
}
notify_msg_t notifyMsg;
notifyMsg.name = const_cast<char*>(mLocalAgentName.c_str());
std::string encoded = MooncakeBase64Helper::encode(syncMessage);
notifyMsg.msg = const_cast<char*>(encoded.c_str());
TLLM_LOG_DEBUG("MooncakeTransferAgent::notifySyncMessage notifyMsg.name: %s, notifyMsg.msg: %s", notifyMsg.name,
notifyMsg.msg);
int ret = genNotifyInEngine(mEngine, segmentId, notifyMsg);
TLLM_CHECK_WITH_INFO(ret == 0, "genNotifyInEngine failed with status: %d", ret);
}
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> MooncakeTransferAgent::getNotifiedSyncMessages()
{
std::unordered_map<std::string, std::vector<SyncMessage>> notifs;
int size = 0;
notify_msg_t* notifyMsgs = getNotifsFromEngine(mEngine, &size);
TLLM_CHECK_WITH_INFO(size >= 0, "getNotifsFromEngine returned negative size: %d", size);
for (int i = 0; i < size; i++)
{
if (notifyMsgs[i].msg == nullptr)
{
TLLM_LOG_WARNING("Message pointer is null for: %s", notifyMsgs[i].name);
continue;
}
std::string decoded = MooncakeBase64Helper::decodeToString(notifyMsgs[i].msg);
notifs[notifyMsgs[i].name].emplace_back(std::move(decoded));
TLLM_LOG_DEBUG("MooncakeTransferAgent::getNotifiedSyncMessages getNotifsFromEngine: %s, %s", notifyMsgs[i].name,
notifyMsgs[i].msg);
}
freeNotifsMsgBuf(notifyMsgs, size);
return notifs;
}
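Consumer-side sketch of the map returned above: each key is the sending agent's name and each value holds the decoded messages in arrival order.

#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cache_transmission/mooncake_utils/transferAgent.h"

void drainNotifications(tensorrt_llm::executor::kv_cache::MooncakeTransferAgent& agent)
{
    for (auto const& [remoteName, messages] : agent.getNotifiedSyncMessages())
    {
        for (auto const& message : messages)
        {
            TLLM_LOG_DEBUG("notification from %s: %zu bytes", remoteName.c_str(), message.size());
        }
    }
}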
bool MooncakeTransferAgent::checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs)
{
TLLM_LOG_DEBUG("MooncakeTransferAgent::checkRemoteDescs");
return true;
}
MooncakeTransferAgent::~MooncakeTransferAgent()
{
destroyTransferEngine(mEngine);
TLLM_LOG_DEBUG("MooncakeTransferAgent::~MooncakeTransferAgent");
}
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config)
{
TLLM_CHECK(config);
return std::make_unique<MooncakeTransferAgent>(*config);
}
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
} // namespace tensorrt_llm::executor::kv_cache

View File

@ -0,0 +1,165 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>
#include "tensorrt_llm/executor/transferAgent.h"
#include "transfer_engine_c.h"
namespace tensorrt_llm::executor::kv_cache
{
class MooncakeTransferStatus final : public TransferStatus
{
public:
MooncakeTransferStatus(transfer_engine_t engine, uint64_t batchId, size_t requestCount);
[[nodiscard]] bool isCompleted() const override;
TransferState wait(int64_t timeout_ms = -1) const override;
private:
transfer_engine_t mEngine;
uint64_t mBatchId;
size_t mRequestCount;
mutable bool mBatchFreed = false;
};
class MooncakeMemoryDesc
{
public:
MooncakeMemoryDesc(MemoryDesc desc)
: mDesc{std::move(desc)}
, mRefCnt{0}
{
}
MooncakeMemoryDesc(MooncakeMemoryDesc const& other)
: mDesc{other.mDesc}
, mRefCnt{0}
{
}
MooncakeMemoryDesc& operator=(MooncakeMemoryDesc const&) = delete;
~MooncakeMemoryDesc() = default;
void addRef() noexcept
{
++mRefCnt;
}
int releaseRef() noexcept
{
return --mRefCnt;
}
int getRefCount() const noexcept
{
return mRefCnt;
}
MemoryDesc const& getDesc() const noexcept
{
return mDesc;
}
private:
MemoryDesc mDesc;
int mRefCnt;
};
class MooncakeBase64Helper
{
public:
static std::string encode(std::vector<uint8_t> const& data);
static std::string encode(std::string const& data);
static std::vector<uint8_t> decode(std::string const& encoded);
static std::string decodeToString(std::string const& encoded);
private:
static const std::string STANDARD_CHARS;
static std::string encodeInternal(std::vector<uint8_t> const& data, std::string const& chars);
static std::vector<uint8_t> decodeInternal(std::string const& encoded, std::string const& chars);
static inline bool isBase64(uint8_t c, std::string const& chars);
static inline bool isWhitespace(uint8_t c);
};
class MooncakeTransferAgent final : public BaseTransferAgent
{
public:
MooncakeTransferAgent(BaseAgentConfig const& config);
~MooncakeTransferAgent();
void registerMemory(RegisterDescs const& descs) override;
void deregisterMemory(RegisterDescs const& descs) override;
void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) override;
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
void invalidateRemoteAgent(std::string const& name) override;
AgentDesc getLocalAgentDesc() override;
ConnectionInfoType getLocalConnectionInfo() override;
[[nodiscard]] std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) override;
void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) override;
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
private:
struct AgentInfo
{
int segmentId;
};
mutable std::mutex mMutex;
transfer_engine_t mEngine;
std::unordered_map<uintptr_t, std::shared_ptr<MooncakeMemoryDesc>> mMemRegInfo;
std::unordered_map<std::string, AgentInfo> mConnectedAgents;
std::string mLocalAgentName;
};
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
extern "C"
{
[[nodiscard]] std::unique_ptr<BaseTransferAgent> createMooncakeTransferAgent(BaseAgentConfig const* config);
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
} // namespace tensorrt_llm::executor::kv_cache

View File

@ -13,6 +13,9 @@
# License for the specific language governing permissions and limitations under
# the License.
# ============================================================================
# NIXL Wrapper Library
# ============================================================================
if(NIXL_ROOT)
find_package(NIXL REQUIRED)
# Check if all required packages were found
@ -30,6 +33,8 @@ if(NIXL_ROOT)
# Add include directories
target_include_directories(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
target_include_directories(${NIXL_WRAPPER_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Link against all NIXL libraries
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE NIXL::nixl)
@ -37,4 +42,85 @@ if(NIXL_ROOT)
# Link against CUDA
target_link_libraries(${NIXL_WRAPPER_TARGET} PRIVATE CUDA::cudart)
set(NIXL_ENABLED TRUE)
else()
set(NIXL_ENABLED FALSE)
endif()
# ============================================================================
# Check if Mooncake wrapper is available (built in mooncake_utils)
# ============================================================================
if(MOONCAKE_ROOT AND TARGET tensorrt_llm_mooncake_wrapper)
set(MOONCAKE_ENABLED TRUE)
else()
set(MOONCAKE_ENABLED FALSE)
endif()
# ============================================================================
# TensorRT-LLM Transfer Agent Binding Python Module Build if either NIXL or
# Mooncake is enabled
# ============================================================================
if(NIXL_ENABLED OR MOONCAKE_ENABLED)
set(TRANSFER_AGENT_BINDING_TARGET "tensorrt_llm_transfer_agent_binding")
# Collect binding source files
set(AGENT_BINDING_SOURCES "")
if(BINDING_TYPE STREQUAL "pybind")
list(APPEND AGENT_BINDING_SOURCES agentBindingsPybind.cpp)
else()
list(APPEND AGENT_BINDING_SOURCES agentBindingsNanobind.cpp)
endif()
if(BINDING_TYPE STREQUAL "pybind")
# Use pybind11 (already fetched via FetchContent)
pybind11_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with pybind11")
else()
# Default to nanobind (already fetched via FetchContent)
nanobind_add_module(${TRANSFER_AGENT_BINDING_TARGET}
${AGENT_BINDING_SOURCES})
message(STATUS "Building tensorrt_llm_transfer_agent_binding with nanobind")
endif()
target_compile_options(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE -Wno-error)
# Add common include directories
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${PROJECT_SOURCE_DIR}/include)
# Conditionally add NIXL support
if(NIXL_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_NIXL)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE NIXL::nixl)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${NIXL_WRAPPER_TARGET})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE NIXL::nixl)
message(STATUS "Transfer agent binding: NIXL support enabled")
endif()
# Conditionally add Mooncake support
if(MOONCAKE_ENABLED)
target_compile_definitions(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ENABLE_MOONCAKE)
target_include_directories(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${TRANSFER_ENGINE_INCLUDE_DIR})
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE tensorrt_llm_mooncake_wrapper)
message(STATUS "Transfer agent binding: Mooncake support enabled")
endif()
# Common dependencies
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET} PRIVATE CUDA::cudart)
target_link_libraries(${TRANSFER_AGENT_BINDING_TARGET}
PRIVATE ${SHARED_TARGET})
# Set RPATH for the module to find wrapper libraries
set_target_properties(
${TRANSFER_AGENT_BINDING_TARGET}
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl"
INSTALL_RPATH "$ORIGIN;$ORIGIN/libs;$ORIGIN/libs/nixl")
endif()

View File

@ -0,0 +1,239 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
namespace kvc = tensorrt_llm::executor::kv_cache;
NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (nanobind)";
// MemoryType enum
nb::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
nb::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
nb::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
nb::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(nb::init<uintptr_t, size_t, uint32_t>(), nb::arg("addr"), nb::arg("len"), nb::arg("device_id"))
.def_prop_ro("addr", &kvc::MemoryDesc::getAddr)
.def_prop_ro("len", &kvc::MemoryDesc::getLen)
.def_prop_ro("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
nb::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(nb::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), nb::arg("type"), nb::arg("descs"))
.def_prop_ro("type", &kvc::MemoryDescs::getType)
.def_prop_ro("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
nb::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(
"__init__",
[](kvc::AgentDesc* self, nb::bytes data)
{
std::string str(data.c_str(), data.size());
new (self) kvc::AgentDesc{std::move(str)};
},
nb::arg("backend_agent_desc"))
.def(nb::init<std::string>(), nb::arg("backend_agent_desc"))
.def_prop_ro("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return nb::bytes(desc.data(), desc.size());
});
// TransferRequest class
nb::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(nb::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
nb::arg("op"), nb::arg("src_descs"), nb::arg("dst_descs"), nb::arg("remote_name"),
nb::arg("sync_message") = std::nullopt)
.def_prop_ro("op", &kvc::TransferRequest::getOp)
.def_prop_ro("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_prop_ro("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_prop_ro("remote_name", &kvc::TransferRequest::getRemoteName)
.def_prop_ro("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
nb::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1);
// BaseAgentConfig struct
nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(nb::init<>())
.def(
"__init__",
[](kvc::BaseAgentConfig* self, std::string name, bool use_prog_thread, bool multi_thread,
bool use_listen_thread, unsigned int num_workers) {
new (self) kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
},
nb::arg("name"), nb::arg("use_prog_thread") = true, nb::arg("multi_thread") = false,
nb::arg("use_listen_thread") = false, nb::arg("num_workers") = 1)
.def_rw("name", &kvc::BaseAgentConfig::mName)
.def_rw("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_rw("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_rw("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_rw("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
nb::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
nb::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// NixlTransferAgent class
nb::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, nb::arg("name"), nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, nb::arg("name"), nb::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
nb::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, nb::arg("timeout_ms") = -1,
nb::call_guard<nb::gil_scoped_release>());
// MooncakeTransferAgent class
nb::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(nb::init<kvc::BaseAgentConfig const&>(), nb::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, nb::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, nb::arg("descs"))
.def("load_remote_agent",
nb::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("agent_desc"))
.def("load_remote_agent_by_connection",
nb::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
nb::arg("name"), nb::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, nb::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
nb::arg("request"), nb::rv_policy::take_ownership, nb::call_guard<nb::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, nb::arg("name"),
nb::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, nb::arg("name"),
nb::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
nb::arg("backend"), nb::arg("config"), nb::rv_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}

View File

@ -0,0 +1,234 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/transferAgent.h"
#ifdef ENABLE_NIXL
#include "transferAgent.h"
#endif
#ifdef ENABLE_MOONCAKE
#include "../mooncake_utils/transferAgent.h"
#endif
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace kvc = tensorrt_llm::executor::kv_cache;
PYBIND11_MODULE(tensorrt_llm_transfer_agent_binding, m)
{
m.doc() = "TensorRT-LLM Transfer Agent Python bindings (pybind11)";
// MemoryType enum
py::enum_<kvc::MemoryType>(m, "MemoryType")
.value("DRAM", kvc::MemoryType::kDRAM)
.value("VRAM", kvc::MemoryType::kVRAM)
.value("BLK", kvc::MemoryType::kBLK)
.value("OBJ", kvc::MemoryType::kOBJ)
.value("FILE", kvc::MemoryType::kFILE);
// TransferOp enum
py::enum_<kvc::TransferOp>(m, "TransferOp")
.value("READ", kvc::TransferOp::kREAD)
.value("WRITE", kvc::TransferOp::kWRITE);
// TransferState enum
py::enum_<kvc::TransferState>(m, "TransferState")
.value("IN_PROGRESS", kvc::TransferState::kIN_PROGRESS)
.value("SUCCESS", kvc::TransferState::kSUCCESS)
.value("FAILURE", kvc::TransferState::kFAILURE);
// MemoryDesc class
py::class_<kvc::MemoryDesc>(m, "MemoryDesc")
.def(py::init<uintptr_t, size_t, uint32_t>(), py::arg("addr"), py::arg("len"), py::arg("device_id"))
.def_property_readonly("addr", &kvc::MemoryDesc::getAddr)
.def_property_readonly("len", &kvc::MemoryDesc::getLen)
.def_property_readonly("device_id", &kvc::MemoryDesc::getDeviceId);
// MemoryDescs class
py::class_<kvc::MemoryDescs>(m, "MemoryDescs")
.def(py::init<kvc::MemoryType, std::vector<kvc::MemoryDesc>>(), py::arg("type"), py::arg("descs"))
.def_property_readonly("type", &kvc::MemoryDescs::getType)
.def_property_readonly("descs", &kvc::MemoryDescs::getDescs);
// AgentDesc class
py::class_<kvc::AgentDesc>(m, "AgentDesc")
.def(py::init(
[](py::bytes data)
{
std::string str(PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()));
return kvc::AgentDesc{std::move(str)};
}),
py::arg("backend_agent_desc"))
.def(py::init<std::string>(), py::arg("backend_agent_desc"))
.def_property_readonly("backend_agent_desc",
[](kvc::AgentDesc const& self)
{
auto const& desc = self.getBackendAgentDesc();
return py::bytes(desc.data(), desc.size());
});
// TransferRequest class
py::class_<kvc::TransferRequest>(m, "TransferRequest")
.def(py::init<kvc::TransferOp, kvc::TransferDescs, kvc::TransferDescs, std::string const&,
std::optional<kvc::SyncMessage>>(),
py::arg("op"), py::arg("src_descs"), py::arg("dst_descs"), py::arg("remote_name"),
py::arg("sync_message") = std::nullopt)
.def_property_readonly("op", &kvc::TransferRequest::getOp)
.def_property_readonly("src_descs", &kvc::TransferRequest::getSrcDescs)
.def_property_readonly("dst_descs", &kvc::TransferRequest::getDstDescs)
.def_property_readonly("remote_name", &kvc::TransferRequest::getRemoteName)
.def_property_readonly("sync_message", &kvc::TransferRequest::getSyncMessage);
// TransferStatus base class
py::class_<kvc::TransferStatus>(m, "TransferStatus")
.def("is_completed", &kvc::TransferStatus::isCompleted)
.def("wait", &kvc::TransferStatus::wait, py::arg("timeout_ms") = -1);
// BaseAgentConfig struct
py::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
.def(py::init<>())
.def(py::init(
[](std::string name, bool use_prog_thread, bool multi_thread, bool use_listen_thread,
unsigned int num_workers) {
return kvc::BaseAgentConfig{
std::move(name), use_prog_thread, multi_thread, use_listen_thread, num_workers};
}),
py::arg("name"), py::arg("use_prog_thread") = true, py::arg("multi_thread") = false,
py::arg("use_listen_thread") = false, py::arg("num_workers") = 1)
.def_readwrite("name", &kvc::BaseAgentConfig::mName)
.def_readwrite("use_prog_thread", &kvc::BaseAgentConfig::useProgThread)
.def_readwrite("multi_thread", &kvc::BaseAgentConfig::multiThread)
.def_readwrite("use_listen_thread", &kvc::BaseAgentConfig::useListenThread)
.def_readwrite("num_workers", &kvc::BaseAgentConfig::numWorkers);
// BaseTransferAgent class (abstract base)
py::class_<kvc::BaseTransferAgent>(m, "BaseTransferAgent")
.def("register_memory", &kvc::BaseTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::BaseTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::BaseTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::BaseTransferAgent::getLocalAgentDesc)
.def("invalidate_remote_agent", &kvc::BaseTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::BaseTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership)
.def(
"notify_sync_message", &kvc::BaseTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::BaseTransferAgent::getNotifiedSyncMessages)
.def("get_local_connection_info", &kvc::BaseTransferAgent::getLocalConnectionInfo)
.def("check_remote_descs", &kvc::BaseTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#ifdef ENABLE_NIXL
// NixlTransferStatus class - release GIL for blocking operations
py::class_<kvc::NixlTransferStatus, kvc::TransferStatus>(m, "NixlTransferStatus")
.def("is_completed", &kvc::NixlTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::NixlTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// NixlTransferAgent class
py::class_<kvc::NixlTransferAgent, kvc::BaseTransferAgent>(m, "NixlTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::NixlTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::NixlTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::NixlTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::NixlTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::NixlTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::NixlTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::NixlTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def(
"notify_sync_message", &kvc::NixlTransferAgent::notifySyncMessage, py::arg("name"), py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::NixlTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::NixlTransferAgent::checkRemoteDescs, py::arg("name"), py::arg("memory_descs"));
#endif
#ifdef ENABLE_MOONCAKE
// MooncakeTransferStatus class - release GIL for blocking operations
py::class_<kvc::MooncakeTransferStatus, kvc::TransferStatus>(m, "MooncakeTransferStatus")
.def("is_completed", &kvc::MooncakeTransferStatus::isCompleted, py::call_guard<py::gil_scoped_release>())
.def("wait", &kvc::MooncakeTransferStatus::wait, py::arg("timeout_ms") = -1,
py::call_guard<py::gil_scoped_release>());
// MooncakeTransferAgent class
py::class_<kvc::MooncakeTransferAgent, kvc::BaseTransferAgent>(m, "MooncakeTransferAgent")
.def(py::init<kvc::BaseAgentConfig const&>(), py::arg("config"))
.def("register_memory", &kvc::MooncakeTransferAgent::registerMemory, py::arg("descs"))
.def("deregister_memory", &kvc::MooncakeTransferAgent::deregisterMemory, py::arg("descs"))
.def("load_remote_agent",
py::overload_cast<std::string const&, kvc::AgentDesc const&>(&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("agent_desc"))
.def("load_remote_agent_by_connection",
py::overload_cast<std::string const&, kvc::ConnectionInfoType const&>(
&kvc::MooncakeTransferAgent::loadRemoteAgent),
py::arg("name"), py::arg("connection_info"))
.def("get_local_agent_desc", &kvc::MooncakeTransferAgent::getLocalAgentDesc)
.def("get_local_connection_info", &kvc::MooncakeTransferAgent::getLocalConnectionInfo)
.def("invalidate_remote_agent", &kvc::MooncakeTransferAgent::invalidateRemoteAgent, py::arg("name"))
.def(
"submit_transfer_requests",
[](kvc::MooncakeTransferAgent& self, kvc::TransferRequest const& request)
{ return self.submitTransferRequests(request).release(); },
py::arg("request"), py::return_value_policy::take_ownership, py::call_guard<py::gil_scoped_release>())
.def("notify_sync_message", &kvc::MooncakeTransferAgent::notifySyncMessage, py::arg("name"),
py::arg("sync_message"))
.def("get_notified_sync_messages", &kvc::MooncakeTransferAgent::getNotifiedSyncMessages)
.def("check_remote_descs", &kvc::MooncakeTransferAgent::checkRemoteDescs, py::arg("name"),
py::arg("memory_descs"));
#endif
// Factory function to create transfer agent by backend name (uses dynamic loading)
m.def(
"make_transfer_agent",
[](std::string const& backend, kvc::BaseAgentConfig const& config) -> kvc::BaseTransferAgent*
{ return kvc::makeTransferAgent(backend, &config).release(); },
py::arg("backend"), py::arg("config"), py::return_value_policy::take_ownership,
"Create a transfer agent by backend name ('nixl' or 'mooncake'). Uses dynamic loading.");
// Expose which backends are available
#ifdef ENABLE_NIXL
m.attr("NIXL_ENABLED") = true;
#else
m.attr("NIXL_ENABLED") = false;
#endif
#ifdef ENABLE_MOONCAKE
m.attr("MOONCAKE_ENABLED") = true;
#else
m.attr("MOONCAKE_ENABLED") = false;
#endif
}
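Both binding variants funnel into the same C++ factory; a minimal sketch of that call path is below, with the backend string and agent name chosen purely for illustration.

#include "tensorrt_llm/executor/transferAgent.h"

#include <memory>

using namespace tensorrt_llm::executor::kv_cache;

std::unique_ptr<BaseTransferAgent> makeExampleAgent()
{
    BaseAgentConfig config{};
    config.mName = "example_agent"; // illustrative agent name
    // "nixl" or "mooncake", mirroring the make_transfer_agent binding above.
    return makeTransferAgent("mooncake", &config);
}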

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <arpa/inet.h>
#include <chrono>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
@ -31,6 +32,7 @@
#include <set>
#include <sys/file.h>
#include <sys/stat.h>
#include <thread>
#include <unistd.h>
#include <vector>
@ -318,10 +320,40 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
TLLM_CHECK(mHandle);
}
void NixlTransferStatus::wait() const
TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
{
while (!isCompleted())
;
auto startTime = std::chrono::steady_clock::now();
while (true)
{
auto status = mRawAgent->getXferStatus(mHandle);
if (status == NIXL_SUCCESS)
{
return TransferState::kSUCCESS;
}
else if (status != NIXL_IN_PROG)
{
return TransferState::kFAILURE;
}
// If timeout_ms < 0, wait indefinitely until status is not NIXL_IN_PROG
if (timeout_ms < 0)
{
std::this_thread::yield();
continue;
}
// Check if timeout has elapsed
auto elapsed
= std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - startTime)
.count();
if (elapsed >= timeout_ms)
{
return TransferState::kIN_PROGRESS;
}
std::this_thread::yield();
}
}
[[nodiscard]] bool NixlTransferStatus::isCompleted() const
@ -333,6 +365,7 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
: mName{config.mName}
{
nixl_status_t status;
if (config.useListenThread)
{
FileLock lock("/tmp/trtllm_nixl_port.lock");
if (!lock.lock())
@ -341,10 +374,18 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
}
auto envPort = common::getEnvNixlPort();
uint16_t port = envPort > 0 ? getIncrmentPort(envPort) : getAvailablePort();
nixlAgentConfig nixlConfig{config.useProgThread, true, port};
nixlAgentConfig nixlConfig{
config.useProgThread, true, port, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mAddress = getAvailableIP() + ":" + std::to_string(port);
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
else
{
mAddress.clear();
nixlAgentConfig nixlConfig{
config.useProgThread, false, 0, nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT, config.numWorkers};
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}
std::string nixlBackend = common::getEnvNixlBackend();
// List of supported backends - extend this list as new backends are added
@ -645,7 +686,8 @@ void NixlLoopbackAgent::executeLoopbackRequest(
std::unique_ptr<TransferStatus> status = this->submitLoopbackRequests(memoryDescs, fileDescs, isOffload);
TLLM_CHECK_WITH_INFO(status != nullptr, "submitLoopbackRequests failed");
status->wait();
TransferState transferState = status->wait();
TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "submitLoopbackRequests failed");
this->deregisterMemory(memoryDescs);
this->deregisterFiles(fileDescs);

View File

@ -45,7 +45,7 @@ public:
[[nodiscard]] bool isCompleted() const override;
void wait() const override;
[[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;
private:
nixlAgent* mRawAgent{};

View File

@ -230,59 +230,62 @@ inline __device__ __host__ T divUp(T m, T n)
// Return (block_size, cluster_size, loads_per_thread)
std::tuple<int, int, int> adjustGridConfig(int numTokens, int dim, int eltsPerThread)
{
// Start with preferred block_size and cluster_size
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
int clusterSize = 8;
#else
int clusterSize = 1;
#endif
static int SM = tensorrt_llm::common::getSMVersion();
int clusterSize = SM >= 90 ? 8 : 1;
int blockSize = 128;
// ========================== Adjust the grid configuration ==========================
int threadsNeeded = divUp(dim, eltsPerThread);
int loadsPerThread = 1;
blockSize = divUp(threadsNeeded, clusterSize);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
if (clusterSize > 1)
{
clusterSize /= 2;
while (threadsNeeded % clusterSize != 0 && clusterSize > 1)
{
clusterSize /= 2;
}
blockSize = divUp(threadsNeeded, clusterSize);
while (blockSize < 128 && clusterSize >= 2)
{
blockSize *= 2;
clusterSize /= 2;
}
int smCount = getMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
{
blockSize *= 2;
clusterSize /= 2;
}
}
blockSize = divUp(threadsNeeded, clusterSize);
while (blockSize < 128 && clusterSize >= 2)
{
blockSize *= 2;
clusterSize /= 2;
}
int smCount = getMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
{
blockSize *= 2;
clusterSize /= 2;
}
#endif
// Try to scale up using multiple loads per thread or a larger CGA
while (blockSize > 1024)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if (clusterSize < 8)
// Scale up with CGA if supported
if (SM >= 90)
{
clusterSize = clusterSize << 1;
if (clusterSize < 8)
{
clusterSize = clusterSize << 1;
}
else
{
break;
}
}
else
{
break;
if (loadsPerThread < 8)
{
loadsPerThread += 1;
}
else
{
break;
}
}
#else
if (loadsPerThread < 8)
{
loadsPerThread += 1;
}
else
{
break;
}
#endif
blockSize = divUp(threadsNeeded, clusterSize * loadsPerThread);
}
return {blockSize, clusterSize, loadsPerThread};
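A worked example of the dispatch arithmetic above, assuming an SM90 part with 132 SMs, 8 tokens, a hidden dimension of 8192, and 8 elements per 16-byte load (bf16); the standalone sketch reproduces only the host-side sizing, not the kernel launch.

#include <cstdio>

static int divUpHost(int m, int n)
{
    return (m + n - 1) / n;
}

int main()
{
    int const numTokens = 8, dim = 8192, eltsPerThread = 8, smCount = 132; // assumed values
    int clusterSize = 8;                                      // SM90 starting point, as above
    int const threadsNeeded = divUpHost(dim, eltsPerThread);  // 1024
    int blockSize = divUpHost(threadsNeeded, clusterSize);    // 128, and 1024 % 8 == 0 so no shrink
    while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512)
    {
        blockSize *= 2; // not taken here: 8 * 8 = 64 <= 132
        clusterSize /= 2;
    }
    std::printf("block_size=%d cluster_size=%d loads_per_thread=1\n", blockSize, clusterSize);
    return 0;
}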
@ -420,9 +423,9 @@ __global__ void __launch_bounds__(1024) oneshotAllreduceFusionKernel(T* outputPt
}
float blockSum = blockReduceSum<float, true>(threadSum);
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
float fullSum = blockSum;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8]; // Temporary variable to share the sum within block
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
int const numBlocks = cluster.num_blocks();
@ -459,6 +462,8 @@ using detail::adjustGridConfig;
void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const eltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@ -466,38 +471,31 @@ void oneshotAllreduceFusionOp(AllReduceFusionParams const& params)
auto [blockSize, clusterSize, loadsPerThread] = adjustGridConfig(numTokens, tokenDim, eltsPerThread);
dim3 grid(numTokens, clusterSize, 1);
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1024 * 8 * eltsPerThread);
#else
1024 * eltsPerThread);
#endif
TLLM_LOG_DEBUG(
"[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: %d, "
"loads_per_thread: %d, "
"threads_needed: %d",
numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, divUp(tokenDim, eltsPerThread));
TLLM_CHECK_WITH_INFO(blockSize <= 1024 && loadsPerThread == 1,
"Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", tokenDim,
1024 * (kSMVersion >= 90 ? 8 : 1) * eltsPerThread);
cudaLaunchAttribute attrs[2];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
attrs[1].id = cudaLaunchAttributeClusterDimension;
attrs[1].val.clusterDim.x = 1;
attrs[1].val.clusterDim.y = clusterSize;
attrs[1].val.clusterDim.z = 1;
#endif
cudaLaunchConfig_t config
{
.gridDim = grid, .blockDim = blockSize, .dynamicSmemBytes = 0, .stream = params.stream, .attrs = attrs,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
.numAttrs = 2,
#else
.numAttrs = 1,
#endif
cudaLaunchConfig_t config{
.gridDim = grid,
.blockDim = blockSize,
.dynamicSmemBytes = 0,
.stream = params.stream,
.attrs = attrs,
.numAttrs = kSMVersion >= 90 ? 2U : 1U,
};
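
As a standalone illustration of the launch pattern above, where the cluster-dimension attribute is requested only when the runtime SM version allows it (the kernel, its argument, and launchWithOptionalCluster are placeholders; the attribute IDs and cudaLaunchKernelEx are the CUDA runtime API already used here):

#include <cuda_runtime.h>

__global__ void exampleKernel(float* out)
{
    out[blockIdx.x * blockDim.x + threadIdx.x] = 1.f;
}

cudaError_t launchWithOptionalCluster(
    float* out, int numTokens, int blockSize, int clusterSize, int smVersion, cudaStream_t stream)
{
    cudaLaunchAttribute attrs[2];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    attrs[1].id = cudaLaunchAttributeClusterDimension;
    attrs[1].val.clusterDim.x = 1;
    attrs[1].val.clusterDim.y = clusterSize;
    attrs[1].val.clusterDim.z = 1;

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = dim3(numTokens, clusterSize, 1);
    cfg.blockDim = dim3(blockSize, 1, 1);
    cfg.dynamicSmemBytes = 0;
    cfg.stream = stream;
    cfg.attrs = attrs;
    cfg.numAttrs = (smVersion >= 90) ? 2u : 1u; // request a cluster only on SM90+
    return cudaLaunchKernelEx(&cfg, exampleKernel, out);
}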
#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, T, RMSNORM) \
@ -831,9 +829,9 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
float blockSum = blockReduceSum<float, true>(threadSum);
float fullSum = blockSum;
__shared__ float sharedVal[8];
// Use CGA Reduction if supported
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__shared__ float sharedVal[8];
int const numBlocks = cluster.num_blocks();
if (numBlocks > 1)
{
@ -876,6 +874,11 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
}
constexpr int kELTS_SIZE = sizeof(T_IN);
    // Issue ACQBLK at the end, assuming the preceding kernel will not modify the buffer_flags.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
// Update the buffer pointers
flag.waitAndUpdate({static_cast<uint32_t>(divUp<uint32_t>(numTokens, worldSize) * worldSize * dim * kELTS_SIZE),
static_cast<uint32_t>(numTokens * dim * kELTS_SIZE), 0, 0});
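
For context on the cudaGridDependencySynchronize() added above: with Programmatic Dependent Launch, the dependent kernel may start before the producer finishes and only blocks when it reaches this call. A minimal producer/consumer sketch (kernel names are placeholders; the consumer must be launched with the programmatic-stream-serialization attribute shown earlier in this file):

#include <cuda_runtime.h>

__global__ void producerSketch(float* buf)
{
    buf[threadIdx.x] = 1.f;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion(); // let the dependent grid launch early
#endif
}

__global__ void consumerSketch(float const* buf, float* out)
{
    // Prologue work that does not read buf can overlap with the producer here.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize(); // now wait until the producer's writes are visible
#endif
    out[threadIdx.x] = buf[threadIdx.x] + 1.f;
}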
@ -883,6 +886,7 @@ __global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OU
void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
{
static int const kSMVersion = tensorrt_llm::common::getSMVersion();
int const numTokens = params.numTokens;
int const tokenDim = params.tokenDim;
int const numEltsPerThread = sizeof(float4) / getDTypeSize(params.dType);
@ -959,17 +963,13 @@ void twoshotAllreduceFusionOp(AllReduceFusionParams const& params)
rnConfig.attrs = rnAttrs;
rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
rnAttrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
#ifndef DISABLE_CGA
rnAttrs[1].id = cudaLaunchAttributeClusterDimension;
rnAttrs[1].val.clusterDim.x = 1;
rnAttrs[1].val.clusterDim.y = rnClusterSize;
rnAttrs[1].val.clusterDim.z = 1;
rnConfig.numAttrs = 2;
#else
rnConfig.numAttrs = 1;
#endif
rnConfig.numAttrs = (kSMVersion >= 90) ? 2U : 1U;
bool const rnUseCGA = rnClusterSize > 1;
bool const rnUseCGA = kSMVersion >= 90 && rnClusterSize > 1;
int const dimPadded = divUp(tokenDim, numEltsPerThread * rnNumThreads) * numEltsPerThread * rnNumThreads;
int const iters = dimPadded / rnNumThreads;
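
For a concrete sense of the padding above (values assumed for illustration, not taken from this diff): with tokenDim = 8192 and a bf16 payload, numEltsPerThread = sizeof(float4) / 2 = 8; with rnNumThreads = 1024, dimPadded = divUp(8192, 8 * 1024) * 8 * 1024 = 8192, so iters = 8192 / 1024 = 8.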

View File

@ -48,6 +48,12 @@ namespace kernels::moe_comm
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
switch (top_k) \
{ \
case 22: \
{ \
constexpr int TOP_K = 22; \
__VA_ARGS__; \
break; \
} \
case 16: \
{ \
constexpr int TOP_K = 16; \
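
The macro above turns the runtime top_k into the compile-time constant TOP_K so the per-k loops in the kernels can be fully unrolled. A sketch of the pattern with the switch written out for two cases (the kernel and launch parameters are illustrative, not the actual moe_comm kernels):

#include <cuda_runtime.h>

template <int TOP_K>
__global__ void exampleTopKKernel(float* out)
{
    // TOP_K is a compile-time constant, so this loop unrolls completely.
#pragma unroll
    for (int k = 0; k < TOP_K; ++k)
    {
        out[threadIdx.x * TOP_K + k] = static_cast<float>(k);
    }
}

void launchForRuntimeTopK(int top_k, float* out, cudaStream_t stream)
{
    switch (top_k) // same shape the SWITCH_TOP_K macro expands to
    {
    case 22: exampleTopKKernel<22><<<1, 32, 0, stream>>>(out); break;
    case 16: exampleTopKKernel<16><<<1, 32, 0, stream>>>(out); break;
    default: break; // unsupported top_k
    }
}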
@ -362,88 +368,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{
@ -452,8 +468,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
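
The hunk above lets a rank with zero local tokens mark itself as the last arrival immediately, so the completion path still runs exactly once. The underlying pattern, one lane per warp voting through an atomic counter and broadcasting the result with __shfl_sync, in a minimal sketch (the counter name and doneFlag epilogue are illustrative):

__device__ unsigned int g_arrivalCounter; // assumed zero before launch

__global__ void lastArriverSketch(int* doneFlag)
{
    unsigned int const totalWarps = gridDim.x * (blockDim.x / 32);
    bool isLast = false;
    if ((threadIdx.x & 31) == 0) // one lane per warp votes
    {
        unsigned int cnt = atomicAdd(&g_arrivalCounter, 1u);
        isLast = (cnt + 1 == totalWarps);
    }
    isLast = __shfl_sync(0xffffffffu, isLast, 0); // broadcast within the warp
    if (isLast)
    {
        *doneFlag = 1; // the epilogue (e.g. cross-rank signaling) runs exactly once
    }
}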
@ -523,7 +546,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
// Prepare kernel pointers struct
@ -568,6 +591,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -577,6 +605,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -626,7 +659,70 @@ __device__ void vectorized_combine_impl(
// Load directly into the per-k accumulator; reduce across k below
acc[k].load(recv_buffer + base_token + offset);
}
if constexpr (TOP_K == 16)
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 22)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
T* a2 = reinterpret_cast<T*>(&acc[2]);
T* a3 = reinterpret_cast<T*>(&acc[3]);
T* a4 = reinterpret_cast<T*>(&acc[4]);
T* a5 = reinterpret_cast<T*>(&acc[5]);
T* a6 = reinterpret_cast<T*>(&acc[6]);
T* a7 = reinterpret_cast<T*>(&acc[7]);
T* a8 = reinterpret_cast<T*>(&acc[8]);
T* a9 = reinterpret_cast<T*>(&acc[9]);
T* a10 = reinterpret_cast<T*>(&acc[10]);
T* a11 = reinterpret_cast<T*>(&acc[11]);
T* a12 = reinterpret_cast<T*>(&acc[12]);
T* a13 = reinterpret_cast<T*>(&acc[13]);
T* a14 = reinterpret_cast<T*>(&acc[14]);
T* a15 = reinterpret_cast<T*>(&acc[15]);
T* a16 = reinterpret_cast<T*>(&acc[16]);
T* a17 = reinterpret_cast<T*>(&acc[17]);
T* a18 = reinterpret_cast<T*>(&acc[18]);
T* a19 = reinterpret_cast<T*>(&acc[19]);
T* a20 = reinterpret_cast<T*>(&acc[20]);
T* a21 = reinterpret_cast<T*>(&acc[21]);
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a1[j];
a2[j] += a3[j];
a4[j] += a5[j];
a6[j] += a7[j];
a8[j] += a9[j];
a10[j] += a11[j];
a12[j] += a13[j];
a14[j] += a15[j];
a16[j] += a17[j];
a18[j] += a19[j];
a20[j] += a21[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a2[j];
a4[j] += a6[j];
a8[j] += a10[j];
a12[j] += a14[j];
a16[j] += a18[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a4[j];
a8[j] += a12[j];
a16[j] += a20[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a8[j];
a0[j] += a16[j];
}
}
else if constexpr (TOP_K == 16)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
@ -710,9 +806,7 @@ __device__ void vectorized_combine_impl(
a0[j] += a8[j];
}
}
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 8)
else if constexpr (TOP_K == 8)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
@ -897,9 +991,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}
#if !DISABLE_SYNC_FOR_PROFILING
@ -951,6 +1055,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif
if (local_num_tokens == 0)
return;
// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
@ -1003,7 +1110,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);
// Configure kernel launch
@ -1011,6 +1118,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}
// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize

View File

@ -26,7 +26,7 @@ namespace kernels::moe_comm
{
// Configuration constants
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size
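
Raising kMaxTopK to 22 pairs with the hand-unrolled TOP_K == 22 reduction earlier in this diff. The same pairwise tree can be expressed generically; a sketch (template and parameter names are assumptions, not the kernel's actual helpers):

// Sketch: reduce acc[0..TOP_K) into acc[0] with the same pairwise tree the
// unrolled TOP_K == 22 branch uses. T is the element type, VecType the packed
// vector type, kElemsPerVec the number of T elements per vector.
template <typename T, typename VecType, int TOP_K, int kElemsPerVec>
__device__ void reduceAccumulators(VecType (&acc)[TOP_K])
{
#pragma unroll
    for (int stride = 1; stride < TOP_K; stride *= 2)
    {
#pragma unroll
        for (int k = 0; k + stride < TOP_K; k += 2 * stride)
        {
            T* dst = reinterpret_cast<T*>(&acc[k]);
            T* src = reinterpret_cast<T*>(&acc[k + stride]);
#pragma unroll
            for (int j = 0; j < kElemsPerVec; ++j)
            {
                dst[j] += src[j];
            }
        }
    }
}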

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6509dd36fb92554c6078595951a8de698d7bdaa07b9b817bfcdd255d4303bca
size 687070
oid sha256:4f1f3679968b8f6dea77f53534af9eb1348b6f476d4c3880833b41dd4cc9c803
size 687860

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b22d606e19b52047ae67319d61f138562f2b81df08ccde3f8fa04f040d408d7a
size 669688
oid sha256:a0d7061b400ab387309af00ae12f7a840b5abb91757183f415ca18329bbdb358
size 670478

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a70e335677a1b0f9d98267fe7701735e42f105720403489276d48a4247ea1b5
size 423835
oid sha256:4a91ff0238b0c8f1d40f8441f22a60a2c64d344b8550de68737292ff449d1d7e
size 426203

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8289200bf78517033295966e9dbdf5c647da9aa7089669ff473ba436fef6a798
size 1230152
oid sha256:4d094c39dbdd372166facb297a4a91be80fb231bf3cca89afa97e61cc725f67e
size 1228572

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97cc5f8d42d92332a92fa216847bbacccc7ef9f9d5208bd26585cd702d03fe57
size 1725040
oid sha256:1fe830d32459fd9a25d54e1d00a98720afd938d9e9042e2b5903f969e991d72d
size 1721882

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1264927817c08da144e387a7258f6c6fe424c0ff159f3ab0d6ffa3c4e3947598
size 375671
oid sha256:09af1ef9197c628c4a31cc58276ee6dcfad03f751069a78b5242594f93ea8c97
size 378039

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:950fb45e94ffc8e2ec9f5a4b682075be55cb85d6415b3eeb172ce2cf7d53220d
size 1140954
oid sha256:9e93bb514c30bc5a4cda8f402a386ab85d079f9b97aeff04788cf3c8a8cc87a6
size 1137008

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba97e1bf342788eaf74a78f542f870d3967214aed98b98600fae772aad5bad5f
size 653960
oid sha256:0dc47824dfc41004c5b243ce9f40eefeee15c69b88474e33ec13137ef56604e8
size 651592

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:337cc83d1880b1496e2f054285472b693c181e081819f425ddf2ea45a5dfe9f4
size 1130682
oid sha256:c0f042eabb29ee9db7ddf9791840337a7544653b295e4b2a5068b7f80bcd8251
size 1128314

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:859ffffa18f1c9c8068a1cfedec487c2e0eab84af2c3720eaa7bb2a044ea16f6
size 1534006
oid sha256:7a9d887dd0acea6d82a25e0dda908f4c5421eaa1ddbfeeb49d382c079156d67e
size 1535586

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02bc55faacb50d0501c590ed11b40d802b374618cbde58db725cc67495762064
size 698136
oid sha256:22a7eaab8e44194acd83621e5546f164ad9cbeda8b67867f864a235036a03931
size 690242

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:510d6c9942dea4bef976c2307fc63f1d7341d78ad8b41cca3bf80bae0a377575
size 380847
oid sha256:e22fe2dde7f5542975db7517b37cdce0eaa656fed2bc58378b37a872c54a43ef
size 374533

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e0d34e15f533f756ac4ad6ef8889e5ed7556d859b6263509f608f2e7194e0a
size 964134

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd7941b92a10c3116b3d93b50ce94d90627ed020e1aa4263b2c46926db60250
size 1008328

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04439f4bdd5bf15dce0d59e455545236ed5b98c963a9b491c40d473eb766a04f
size 988580
oid sha256:ec624d7dceea5234b9dd4e43125f271e46ed4f2a4118837a23e00eb89571dcb2
size 985422

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:46413d67059a15237e0f7f26b4d75c1965d135c4b28a1effe3b6f40a51dbe543
size 606983

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c526229c1eea9eec08dd2c4a6d9f2052e54d6ece9f4fdf0b9a73f371e72ae36
size 614063
oid sha256:d33f3798292038d22d4e61732da397b3466a8983892fcc14448f63e5705d2dd0
size 629062

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d07d4142403fc5d3004e6831b12f1cf3236c397e61448cbe49e7c7e47a5aef4
size 2482034

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26232545325ecf363f12b49db62c92a1294dc155ea22cb6e6593fc920b734aec
size 1862432

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba18343579abe3aee6e75545c8ec25a244d24864ff69c23146ee2a13b5eccdd4
size 1916872
oid sha256:41df1bdb745c0efd7961c44dbcd30a1bad202103d301ca785b5b7cdef3cd0ce9
size 1882140

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0e6ba601df471fff8f5beb1bdb9af6b8f41f57903ee762bb042b023917882d95
size 2608304
oid sha256:053ddc81e3885a583adb9bfbfea6a263f023638a2162430dc62faeba1b101d37
size 2527002

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25f59e66bbafb18273bf7fc34eade730ef12e805e59abb6ef345b6d8574b8eb8
size 565135

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91906c0234a97455f5733ec6e359d61e9b9a0481aa72fd5eec72ae5cc04b8d22
size 571425
oid sha256:2194a3863b3dd880c0788260e6368d8453050e7c02e7055eb8d9b87f4ce32429
size 588001

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19a154061aa471d1ef60d6a6e6cc32afe7d5fc9e0856059c22220d333f855318
size 2291002

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6233042269b80360ec9b76dff1204188338c85d4c647c234df21405138e8e584
size 704076
oid sha256:3fbf61a84913ea7865981c9d2df49a2c4db4aff6959e0864ba619878af8894dd
size 641720

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73c371164cb89be22699cfc3973d6b3bc03a892fed504f74e89f17b7130deb12
size 1765330

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee37ada8d3e1d32b5b7227008e29a73e1b2e2dcfcd9d63a25f818a607445d4ca
size 1798458
oid sha256:17b06132679a9db8eb555012bfb53fe941ea092126af235837deff4848b3b13b
size 1786618

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c69925c289bbda6bb7579eb6c84d1432612a96485ee97bdc04dcbba035c93da
size 2342284
oid sha256:f2ffd14c273aeb544cf055e6b984f25b03116eb30d067c98bf00af306ec55962
size 2335970

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f310dc88b134cfee3c3ef703bb764c175bfeacbef3845ad8e75fbf3bbe9d75c
size 604267
oid sha256:0bb606a262a25c8cdb18ee9beff02931a133ebebe7777600479241d291825b9e
size 602689

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:547ad9a31f84c26651688c8911e566c9a05ac88283de8d54c8031017a4f51105
size 917634
oid sha256:90c07881943263544ffc233453b9b5613351e07fdef3dd21bb893134fecc304f
size 916844

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:05b540e55a0124ab1599f69dae20172b17ef20688a24edc8db841f90a1952e8f
size 1384932
oid sha256:b6d7ee26961634f6a7b62d8adae6c97927e85d9fbc8182ef0b1d59ee9d5e2cfb
size 1378616

Some files were not shown because too many files have changed in this diff