Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

[None][feat] add fp4 gemm + allreduce (#9729)

Signed-off-by: benzh
Signed-off-by: benzh-2025

parent c1b0b7350f
commit 6df2c8a074
@@ -141,7 +141,7 @@ public:
    // Epilogue
    ////////////////
    using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
    using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, false>;
    using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
    using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
    using FusionOp

@@ -191,6 +191,7 @@ GemmAllReduceImplRunner<GemmTraits>::GemmAllReduceImplRunner()
        break;
    // Blackwell
    case 100:
    case 103:
        registry_builder.addSm100<GemmTraits, GemmAllReduceImpl::kNVLS_2SHOT, _2SM, TileShape::TileShape_128x256x128,
            ClusterShape::ClusterShape_4x1x1>();
        break;

@@ -104,7 +104,8 @@ add_library(
  loraOp.cpp
  finegrained_mixed_dtype_gemm_thop.cpp
  tinygemm2.cpp
  dsv3RopeOp.cpp)
  dsv3RopeOp.cpp
  fusedGemmAllreduceOp.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(
  th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}

cpp/tensorrt_llm/thop/fusedGemmAllreduceOp.cpp (new file, 300 lines)
@@ -0,0 +1,300 @@
/*
 * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_extensions/gemm_configs.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>

#include <cstddef>
#include <cuda_fp16.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;

namespace
{
struct AllocationKey
{
    int64_t device_index;
    std::set<int> group;

    bool operator==(AllocationKey const& other) const
    {
        return device_index == other.device_index && group == other.group;
    }

    std::string toString() const
    {
        std::stringstream ss;
        ss << "AllocationKey(device: " << device_index << ", group: [";
        for (int rank : group)
        {
            ss << rank << ", ";
        }
        ss << "])";
        return ss.str();
    }
};

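// Hash functor so AllocationKey can be used as a key in the std::unordered_map caches below;
// the device index and every rank in the group are mixed with a boost-style hash_combine.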
struct AllocationKeyHash
{
    size_t operator()(AllocationKey const& key) const
    {
        size_t seed = 0;

        // Hash the device index
        hash_combine(seed, key.device_index);

        // Hash the set elements
        for (auto const& elem : key.group)
        {
            hash_combine(seed, elem);
        }

        return seed;
    }

private:
    template <typename T>
    static void hash_combine(size_t& seed, T const& val)
    {
        seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
};

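// RAII wrapper around an NVLS multicast allocation: ipcNvlsAllocate() in the constructor,
// ipcNvlsFree() in the destructor, so the handle lives exactly as long as the cached workspace.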
class IpcNvlsHandleWrapper
{
public:
    IpcNvlsHandleWrapper(size_t size, std::set<int> groups)
        : mSize(size)
    {
        mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups);
    }

    tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const
    {
        return mHandle;
    }

    size_t getSize() const
    {
        return mSize;
    }

    ~IpcNvlsHandleWrapper()
    {
        tensorrt_llm::runtime::ipcNvlsFree(mHandle);
    }

private:
    size_t mSize;
    tensorrt_llm::runtime::IpcNvlsHandle* mHandle;
};

std::once_flag init_flag;

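// Preferred size of the multicast output buffer (default 128 MB). It can be raised through the
// TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE environment variable (in bytes), read once per process.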
size_t getPreferredWorkspaceSize()
{
    // 128MB
    static size_t preferredWorkspaceSize = 134217728;
    std::call_once(init_flag,
        [&]()
        {
            char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE");
            size_t workspaceSize = 0;
            if (envWorkspaceSize != nullptr)
            {
                workspaceSize = std::atoi(envWorkspaceSize);
            }
            preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize);
        });
    return preferredWorkspaceSize;
}

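// Caches one persistent workspace plus one NVLS handle per (device, rank-group) AllocationKey,
// so repeated fused GEMM+allreduce launches for the same group reuse the multicast buffer.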
class GemmAllreduceNvlsMemoryManager
{
public:
    GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager constructor");
    }

    ~GemmAllreduceNvlsMemoryManager()
    {
        TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager destructor");
    }

    std::pair<PersistentWorkspaceInterface*, tensorrt_llm::runtime::IpcNvlsHandle*> getWorkspace(
        GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem,
        AllocationKey const& key)
    {
        int M = std::get<0>(problem.problem_size);
        int N = std::get<1>(problem.problem_size);
        size_t requiredSize = M * N * 2;
        size_t preferredWorkspaceSize = getPreferredWorkspaceSize();
        if (requiredSize > preferredWorkspaceSize)
        {
            std::stringstream ss;
            ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes";
            TLLM_THROW("%s", ss.str().c_str());
        }

        auto handle = mHandles[key];
        if (handle == nullptr)
        {
            TLLM_LOG_DEBUG("Creating allreduce workspace for %s", key.toString().c_str());
            handle = std::make_shared<IpcNvlsHandleWrapper>(preferredWorkspaceSize, key.group);
            GemmAllReduceImplInterface::ProblemArgs tmpArgs;
            int maxN = 16384;
            int maxM = preferredWorkspaceSize / (maxN * 2);
            tmpArgs.argProblemShape(maxM, maxN, 512, 1)
                .argRanks(problem.rank, problem.ranks)
                .argLaunchConfig(runner->getSupportedLaunchConfigs()[0]);
            auto workspace = runner->getPersistentWorkspace(tmpArgs);
            workspace->allocate();
            mWorkspaces[key] = workspace;
            mHandles[key] = handle;
        }
        return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle());
    }

private:
    std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces;
    std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles;
};

GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager()
{
    static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager;
    return &gNvlsMemoryManager;
}

at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem,
    at::ScalarType outputDtype, c10::cuda::CUDAStream stream)
{
    AllocationKey key{stream.device_index(), problem.ranks};
    auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key);
    problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());
    problem.argWorkspace(workspace);
    runner->run(problem, stream);
    size_t dSize
        = std::get<0>(problem.problem_size) * std::get<1>(problem.problem_size) * c10::elementSize(outputDtype);
    auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)}, outputDtype,
        stream.device(), std::nullopt);
    TLLM_CUDA_CHECK(cudaMemcpyAsync(
        D.data_ptr(), reinterpret_cast<void const*>(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream));
    return D;
}
} // namespace

namespace torch_ext
{

class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
    explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
        : mOutputDtype(outputDtype)
        , mRank(rank)
    {
        for (int64_t rank : group)
        {
            mGroup.insert(static_cast<int>(rank));
        }

        if (outputDtype == at::ScalarType::Half)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
                cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else if (outputDtype == at::ScalarType::BFloat16)
        {
            using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
                cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
                cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
            mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
        }
        else
        {
            TLLM_THROW("Unsupported output dtype: %s", torch::toString(outputDtype));
        }

        mConfigs = mRunner->getSupportedLaunchConfigs();
    }

    at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
        at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
    {
        if (configIdx < 0)
            configIdx = 0;

        TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");
        const int64_t M = mat1.size(0);
        const int64_t N = mat2.size(0);
        const int64_t K = mat1.size(1) * 2;

        GemmAllReduceImplInterface::ProblemArgs problemArgs;
        problemArgs.argProblemShape(M, N, K, 1);
        problemArgs.argA(mat1.data_ptr());
        problemArgs.argB(mat2.data_ptr());
        problemArgs.argAScale(mat1Scale.data_ptr());
        problemArgs.argBScale(mat2Scale.data_ptr());
        problemArgs.argC(nullptr);
        problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
        problemArgs.argBeta(0.f);
        problemArgs.argRanks(mRank, mGroup);
        problemArgs.argLaunchConfig(mConfigs[configIdx]);

        auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
        return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream);
    }

    int64_t getNumConfigs() const
    {
        return static_cast<int64_t>(mConfigs.size());
    }

private:
    at::ScalarType mOutputDtype;
    int mRank;
    std::set<int> mGroup;
    std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
    std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
        .def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
        .def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
        .def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
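
The registered class is what the Python wrapper below binds to via torch.classes.trtllm. A minimal
sketch of driving it directly, assuming one process per GPU in the TP group, NVFP4-packed operands,
and block scales laid out as NVFP4LinearMethod prepares them (all tensor names are hypothetical and
not part of the commit):

# sketch_fp4_gemm_allreduce.py (illustrative only)
from typing import List

import torch

import tensorrt_llm  # importing registers torch.classes.trtllm / torch.ops.trtllm


def fused_gemm_allreduce(act_fp4: torch.Tensor, weight_fp4: torch.Tensor,
                         act_sf: torch.Tensor, weight_sf: torch.Tensor,
                         alpha: torch.Tensor,
                         tp_group: List[int]) -> torch.Tensor:
    # Output dtype must be fp16 or bf16 (the two GemmTypes instantiations above).
    runner = torch.classes.trtllm.Fp4GemmAllreduceRunner(
        torch.float16, tensorrt_llm.mpi_rank(), tp_group)
    # Any index in [0, get_num_configs()) picks a launch config; a negative index falls back to 0.
    return runner.run_gemm(act_fp4, weight_fp4, act_sf, weight_sf, alpha, 0)
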
@@ -1869,3 +1869,99 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
    stream = get_stream(stream_id)
    assert stream is not None
    tensor.record_stream(stream)


class Fp4GemmAllreduceRunner(TunableRunner):
    runner_dict = dict()
    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        0, 0, get_last_power_of_2_num_tokens_buckets,
        last_positive_power_of_2), ),
                                 constraint_specs=(ConstraintSpec(
                                     2, 0, fp4_scale_infer_shape), ))

    def __init__(
        self,
        output_dtype: torch.dtype,
        tp_rank: int,
        tp_group: List[int],
    ):
        self.output_dtype = output_dtype
        self.tp_rank = tp_rank
        self.tp_group_str = '-'.join(str(g) for g in tp_group)
        instance_key = (output_dtype, self.tp_group_str)
        if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
            Fp4GemmAllreduceRunner.runner_dict[
                instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
                    output_dtype, tp_rank, tp_group)
        self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
            instance_key]

    def unique_id(self):
        return (self.output_dtype, self.tp_group_str)

    def get_valid_tactics(self, inputs: List[torch.Tensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = 0,
    ) -> torch.Tensor:
        mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
        return self.fp4_gemm_all_reduce_runner.run_gemm(
            mat1,
            mat2,
            mat1_scale,
            mat2_scale,
            global_scale,
            tactic,
        )


@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
def nvfp4_gemm_allreduce(
    act_fp4: torch.Tensor,
    weight_fp4: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    AutoTuner.get()

    # Use Cutlass runner with predefined configs
    nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
        output_dtype, tp_rank, tp_group)

    # TODO: Enable auto-tuning
    # runner_type = type(nvfp4_gemm_allreduce_runner).__name__
    # _, best_tactic = tuner.choose_one(
    #     f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
    #     [nvfp4_gemm_allreduce_runner],
    #     nvfp4_gemm_allreduce_runner.tuning_config,
    #     [act_fp4, weight, act_sf, weight_scale, alpha],
    # )

    best_tactic = -1

    return nvfp4_gemm_allreduce_runner(
        inputs=[act_fp4, weight_fp4, act_sf, weight_scale, alpha],
        tactic=best_tactic)


@nvfp4_gemm_allreduce.register_fake
def _(
    act_fp4: torch.Tensor,
    weight_fp4: torch.Tensor,
    act_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    tp_rank: int,
    tp_group: List[int],
) -> torch.Tensor:
    return act_fp4.new_empty((act_fp4.size(0), weight_fp4.size(0)),
                             dtype=output_dtype)

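
The register_fake entry above gives the op a meta implementation for tracing and torch.compile: the
output has shape (act_fp4.size(0), weight_fp4.size(0)) in output_dtype, since the packed FP4 operands
carry two values per byte. A condensed sketch of the call as made from
NVFP4LinearMethod.apply_linear_allreduce further below (tensor names are assumptions; operands must
already be NVFP4-quantized):

# illustrative only
from typing import List

import torch

import tensorrt_llm  # registers torch.ops.trtllm.nvfp4_gemm_allreduce


def nvfp4_linear_allreduce(act_fp4: torch.Tensor, act_sf: torch.Tensor,
                           weight_fp4: torch.Tensor, weight_sf: torch.Tensor,
                           alpha: torch.Tensor, out_dtype: torch.dtype,
                           tp_rank: int, tp_group: List[int]) -> torch.Tensor:
    # The GEMM output lands in the NVLS multicast workspace and is reduced across tp_group.
    return torch.ops.trtllm.nvfp4_gemm_allreduce(act_fp4, weight_fp4, act_sf,
                                                 weight_sf, alpha, out_dtype,
                                                 tp_rank, tp_group)
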
@@ -14,7 +14,8 @@ from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, MoEAllReduce)
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_version, mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger
@@ -673,9 +674,51 @@ class LlamaDecoderLayer(DecoderLayer):
        # Disable fusion for small models due to accuracy issues
        self.enable_fusion &= config.hidden_size > 4096

        self.PRE_MLP_FUSION = self.mapping.has_tp(
        enable_gemm_allreduce_fusion = (os.environ.get(
            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1") == "1")
        mpi_enabled = not mpi_disabled()
        dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
        tp_valid = self.mapping.tp_size > 1
        quant_valid = self.is_nvfp4 is not None and self.is_nvfp4

        device_supported = get_sm_version() >= 100
        nvls_supported = ipc_nvls_supported()

        use_fused_gemm_allreduce = all([
            enable_gemm_allreduce_fusion, mpi_enabled, dtype_supported,
            tp_valid, quant_valid, device_supported, nvls_supported
        ])

        def check_in_out_features(in_features, out_features):
            in_feature_valid = in_features % 128 == 0 and in_features >= 1024
            out_feature_valid = out_features % 64 == 0 and out_features >= 1024
            return all([in_feature_valid, out_feature_valid])

        num_heads = config.num_attention_heads
        head_dim = getattr(config, 'head_dim', None)
        if not isinstance(head_dim, int):
            head_dim = config.hidden_size // num_heads

        in_features = num_heads * head_dim
        out_features = config.hidden_size
        in_out_features_valid = check_in_out_features(in_features, out_features)

        attn_fused_gemm_allreduce = all(
            [use_fused_gemm_allreduce, in_out_features_valid])
        self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp(
        ) and not self.enable_attention_dp and self.enable_fusion
        self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion

        in_features = config.intermediate_size
        out_features = config.hidden_size
        in_features_aligned_with_tp = in_features % self.mapping.tp_size == 0
        in_out_features_valid = check_in_out_features(
            in_features // self.mapping.tp_size, out_features)
        mlp_fused_gemm_allreduce = all([
            use_fused_gemm_allreduce, in_features_aligned_with_tp,
            in_out_features_valid
        ])
        self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp(
        ) and self.enable_fusion

        if self.is_nvfp4:
            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4

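
A worked example of the shape gate above, with illustrative Llama-style numbers that are not taken
from the commit:

# illustrative only
def check_in_out_features(in_features: int, out_features: int) -> bool:
    # Same rule as the nested helper above: in_features a multiple of 128, out_features a
    # multiple of 64, and both at least 1024.
    return (in_features % 128 == 0 and in_features >= 1024
            and out_features % 64 == 0 and out_features >= 1024)


num_heads, head_dim, hidden_size = 32, 128, 4096
intermediate_size, tp_size = 14336, 2

assert check_in_out_features(num_heads * head_dim, hidden_size)  # attention projection qualifies
assert intermediate_size % tp_size == 0
assert check_in_out_features(intermediate_size // tp_size, hidden_size)  # MLP projection qualifies
# When these hold (and the runtime checks above pass), PRE_MLP_FUSION / POST_MLP_FUSION are turned
# off so the row-parallel linears take the fused GEMM+allreduce path instead.
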
@@ -14,7 +14,8 @@ from torch.nn.parameter import Parameter

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm._utils import is_device_integrated
from tensorrt_llm._utils import is_device_integrated, mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                     AllReduceStrategy)
from tensorrt_llm.logger import logger
@@ -307,6 +308,11 @@ class LinearMethodBase(ABC):
              bias: Optional[torch.Tensor], *args, **kwargs):
        raise NotImplementedError

    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
                               bias: Optional[torch.Tensor], tp_rank: int,
                               tp_group: List[int], *args, **kwargs):
        raise NotImplementedError

    def load_weights(self,
                     module: Linear,
                     weights: List[Dict],
@@ -908,8 +914,7 @@ class NVFP4LinearMethod(LinearMethodBase):
        else:
            module.register_parameter("bias", None)

    def apply(self, module: Linear, input: torch.Tensor,
              bias: Optional[torch.Tensor]):
    def _input_prepare(self, module: Linear, input: torch.Tensor):
        if isinstance(input, Fp4QuantizedTensor):
            # Input is already quantized - this should not happen if pre_quant_scale exists
            # because we disable FP4 output for attention output when pre_quant_scale is present
@@ -935,7 +940,11 @@

        act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
            input, module.input_scale, module.scaling_vector_size, False)
        return act_fp4, act_sf

    def apply(self, module: Linear, input: torch.Tensor,
              bias: Optional[torch.Tensor]):
        act_fp4, act_sf = self._input_prepare(module, input)
        # Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
        # Convert list to comma-separated string for torch.compile compatibility
        allowed_backends_str = ','.join(module.nvfp4_allowed_backends)
@@ -956,6 +965,21 @@
            output = output + bias
        return output

    def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
                               bias: Optional[torch.Tensor], tp_rank: int,
                               tp_group: List[int]):
        act_fp4, act_sf = self._input_prepare(module, input)
        output = torch.ops.trtllm.nvfp4_gemm_allreduce(
            act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
            module.dtype, tp_rank, tp_group)
        # Take the dim of out_features if padded. Make sure the output is contiguous
        if output.shape[-1] > module.out_features:
            output = output[..., :module.out_features].contiguous()

        if bias is not None:
            output = output + bias
        return output

    def load_kv_scales(self, weights: List[Dict]):
        k_scale, v_scale = [], []
        for w in weights:
@@ -2133,6 +2157,23 @@ class Linear(nn.Module):
        self.use_custom_cublas_mm = use_custom_cublas_mm
        self.lora = lora

        mpi_enabled = not mpi_disabled()
        dtype_supported = self.dtype in (torch.float16, torch.bfloat16)
        in_features_aligned = self.in_features % 128 == 0
        out_features_aligned = self.out_features % 64 == 0
        tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1
        quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
        )

        device_supported = get_sm_version() >= 100
        nvls_supported = ipc_nvls_supported()

        self.use_fused_gemm_allreduce = all([
            self.reduce_output, mpi_enabled, dtype_supported,
            in_features_aligned, out_features_aligned, tp_valid, quant_valid,
            device_supported, nvls_supported
        ])

        self.enable_cuda_core = False
        if torch.cuda.is_available():
            capability = torch.cuda.get_device_capability(
@@ -2224,13 +2265,20 @@
                     lora_params: Optional[dict] | None = None,
                     layer_idx: Optional[int] | None = None):
        output = self.quant_method.apply(self, input, bias)

        if self.lora is not None and bool(lora_params):
            lora_result = self.lora(input, lora_params, layer_idx)
            if lora_result is not None:
                output = output + lora_result
        return output

    def apply_linear_allreduce(self,
                               input,
                               bias,
                               layer_idx: Optional[int] | None = None):
        output = self.quant_method.apply_linear_allreduce(
            self, input, bias, self.tp_rank, self.mapping.tp_group)
        return output

    def _maybe_fuse_bias_into_allreduce(
        self,
        bias: Optional[torch.Tensor],
@@ -2257,16 +2305,23 @@
        layer_idx: Optional[int] = None,
    ) -> torch.Tensor:
        if self.tp_mode == TensorParallelMode.ROW:
            use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and lora_params is None
            if use_fused_gemm_allreduce and all_reduce_params is not None:
                use_fused_gemm_allreduce = all_reduce_params.enable_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE

            bias = None if (self.tp_rank > 0) else self.bias
            if self.reduce_output:
                fuse_bias = self._maybe_fuse_bias_into_allreduce(
                    bias, all_reduce_params)
                bias = None if fuse_bias else bias
                output = self.apply_linear(input, bias, lora_params, layer_idx)
                output = self.all_reduce(
                    output,
                    all_reduce_params=all_reduce_params,
                )
                if use_fused_gemm_allreduce:
                    output = self.apply_linear_allreduce(
                        input, self.bias, layer_idx)
                else:
                    fuse_bias = self._maybe_fuse_bias_into_allreduce(
                        bias, all_reduce_params)
                    bias = None if fuse_bias else bias
                    output = self.apply_linear(input, bias, lora_params,
                                               layer_idx)
                    output = self.all_reduce(
                        output, all_reduce_params=all_reduce_params)
            else:
                output = self.apply_linear(input, bias, lora_params, layer_idx)
        elif self.tp_mode == TensorParallelMode.COLUMN:

@@ -7,11 +7,16 @@ import pytest
import torch
from mpi4py import MPI
from torch import nn
from utils.util import skip_pre_blackwell

import tensorrt_llm
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import pad_up
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@@ -161,7 +166,6 @@ def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

    xs = torch.chunk(x, 2, dim=-1)
    l0 = torch.compile(l0, fullgraph=True)
    output = l0.forward(xs[tensor_parallel_rank])
@@ -333,3 +337,150 @@ def test_row_linear_norm_fusion(seq_len, hidden_size, mpi_pool_executor):
                   [l0_weight], hidden_size, dtype)] * 2))
    for r in results:
        assert r is True


def check_accuracy(a, b, atol, rtol, percent):
    assert a.shape == b.shape
    assert a.dtype == b.dtype
    a = a.to(torch.float32)
    b = b.to(torch.float32)
    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if not (mismatch_percent < 1 - percent):
        raise Exception("Mismatch percentage is %f for rtol %f" %
                        (mismatch_percent, rtol))


@torch.inference_mode
def fp4_row_linear_allreduce(tp_size, local_rank, seq_len, output_size,
                             hidden_size, dtype, output_ref, x_sf_global,
                             w_sf_global, x_fp4s, w_fp4, x_sf_blocks,
                             w_sf_block_unswizzled):
    output_ref = output_ref.cuda()
    x_sf_global = x_sf_global.cuda()
    w_sf_global = w_sf_global.cuda()
    x_fp4 = x_fp4s[local_rank].cuda()
    w_fp4 = w_fp4.cuda()
    x_sf_block = x_sf_blocks[local_rank].cuda()
    w_sf_block_unswizzled = w_sf_block_unswizzled.cuda()

    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    l0 = Linear(
        in_features=hidden_size,
        out_features=output_size,
        bias=False,
        dtype=dtype,
        quant_config=qc,
        mapping=Mapping(
            world_size=tp_size,
            tp_size=tp_size,
            rank=local_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    )

    l0.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(torch.float8_e4m3fn),
        'weight_scale_2':
        1.0 / w_sf_global.cpu()
    }])

    l0.cuda()
    # TODO: parameters['weight']' size mismatch at index 0
    # l0 = torch.compile(l0)
    with torch.inference_mode(), autotune():
        output = l0.forward((x_fp4, x_sf_block))

    torch.cuda.synchronize()
    check_accuracy(output, output_ref, atol=0.05, rtol=0.05, percent=0.99)


def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len,
                                             output_size, hidden_size, dtype,
                                             output_ref, x_sf_global,
                                             w_sf_global, x_fp4s, w_fp4,
                                             x_sf_blocks,
                                             w_sf_block_unswizzled):
    local_rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(local_rank)

    try:
        func(tp_size, local_rank, seq_len, output_size, hidden_size, dtype,
             output_ref, x_sf_global, w_sf_global, x_fp4s, w_fp4, x_sf_blocks,
             w_sf_block_unswizzled)
    except Exception as e:
        print(f"Error: {e}")
        raise
    return True


@skip_pre_blackwell
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("seq_len", [256, 400], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("output_size", [32, 64], ids=lambda x: f"output:{x}")
@pytest.mark.parametrize("hidden_size", [128, 256], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16],
                         ids=lambda x: f"dtype:{x}")
@pytest.mark.parametrize("mpi_pool_executor", [2],
                         indirect=True,
                         ids=lambda x: f"tp_size:{x}")
def test_fp4_row_linear_allreduce(seq_len, output_size, hidden_size, dtype,
                                  mpi_pool_executor):
    torch.manual_seed(42)
    tp_size = mpi_pool_executor.num_workers

    x = torch.randn((seq_len, hidden_size), dtype=dtype).cuda()
    w = torch.randn((output_size, hidden_size), dtype=dtype).cuda()

    scaling_vector_size = 16
    x_sf_global = (448 * 6) / x.abs().max().float()
    x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global,
                                                      scaling_vector_size,
                                                      False)
    w_sf_global = (448 * 6) / w.abs().max().float()
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)
    w_sf_block_unswizzled = (torch.ops.trtllm.block_scale_interleave_reverse(
        w_sf_block.cpu().view(pad_up(output_size, 128), -1)))

    with torch.inference_mode():
        alpha_ref = 1.0 / (w_sf_global * x_sf_global)
        output_ref = torch.ops.trtllm.fp4_gemm(
            x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_ref,
            fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, dtype)

    torch.cuda.synchronize()

    xs = [x.contiguous().cuda() for x in torch.chunk(x, tp_size, dim=-1)]
    x_fp4s = []
    x_sf_blocks = []
    for i in range(tp_size):
        _fp4, _sf_block = torch.ops.trtllm.fp4_quantize(xs[i], x_sf_global,
                                                        scaling_vector_size,
                                                        False)
        x_fp4s.append(_fp4.cpu())
        x_sf_blocks.append(_sf_block.cpu())

    output_ref = output_ref.cpu()
    x_sf_global = x_sf_global.cpu()
    w_sf_global = w_sf_global.cpu()
    w_fp4 = w_fp4.cpu()
    w_sf_block_unswizzled = w_sf_block_unswizzled.cpu()

    results = mpi_pool_executor.map(
        fp4_row_linear_allreduce_run_single_rank,
        *zip(*[(fp4_row_linear_allreduce, tp_size, seq_len, output_size,
                hidden_size, dtype, output_ref, x_sf_global, w_sf_global,
                x_fp4s, w_fp4, x_sf_blocks, w_sf_block_unswizzled)] * tp_size))

    for r in results:
        assert r is True