[None][feat] add fp4 gemm + allreduce (#9729)

Signed-off-by: benzh 
Signed-off-by: benzh-2025
benzh-2025 2026-01-13 21:11:13 +08:00 committed by GitHub
parent c1b0b7350f
commit 6df2c8a074
8 changed files with 665 additions and 18 deletions

View File

@@ -141,7 +141,7 @@ public:
// Epilogue
////////////////
using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, false>; // disable sync on arrival
using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOp

View File

@@ -191,6 +191,7 @@ GemmAllReduceImplRunner<GemmTraits>::GemmAllReduceImplRunner()
break;
// Blackwell
case 100:
case 103:
registry_builder.addSm100<GemmTraits, GemmAllReduceImpl::kNVLS_2SHOT, _2SM, TileShape::TileShape_128x256x128,
ClusterShape::ClusterShape_4x1x1>();
break;

View File

@@ -104,7 +104,8 @@ add_library(
loraOp.cpp
finegrained_mixed_dtype_gemm_thop.cpp
tinygemm2.cpp
dsv3RopeOp.cpp)
dsv3RopeOp.cpp
fusedGemmAllreduceOp.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(
th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}

View File

@@ -0,0 +1,300 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass_extensions/gemm_configs.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <ATen/cuda/EmptyTensor.h>
#include <cstddef>
#include <cstdint>
#include <cuda_fp16.h>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <vector>
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplRunner;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmAllReduceImplInterface;
using tensorrt_llm::kernels::opened_cutlass_kernels::GemmTypes;
using tensorrt_llm::kernels::opened_cutlass_kernels::PersistentWorkspaceInterface;
namespace
{
struct AllocationKey
{
int64_t device_index;
std::set<int> group;
bool operator==(AllocationKey const& other) const
{
return device_index == other.device_index && group == other.group;
}
std::string toString() const
{
std::stringstream ss;
ss << "AllocationKey(device: " << device_index << ", group: [";
for (int rank : group)
{
ss << rank << ", ";
}
ss << "])";
return ss.str();
}
};
struct AllocationKeyHash
{
size_t operator()(AllocationKey const& key) const
{
size_t seed = 0;
// Hash the device index
hash_combine(seed, key.device_index);
// Hash the set elements
for (auto const& elem : key.group)
{
hash_combine(seed, elem);
}
return seed;
}
private:
template <typename T>
static void hash_combine(size_t& seed, T const& val)
{
seed ^= std::hash<T>()(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
};
class IpcNvlsHandleWrapper
{
public:
IpcNvlsHandleWrapper(size_t size, std::set<int> groups)
: mSize(size)
{
mHandle = tensorrt_llm::runtime::ipcNvlsAllocate(size, groups);
}
tensorrt_llm::runtime::IpcNvlsHandle* getHandle() const
{
return mHandle;
}
size_t getSize() const
{
return mSize;
}
~IpcNvlsHandleWrapper()
{
tensorrt_llm::runtime::ipcNvlsFree(mHandle);
}
private:
size_t mSize;
tensorrt_llm::runtime::IpcNvlsHandle* mHandle;
};
std::once_flag init_flag;
size_t getPreferredWorkspaceSize()
{
// 128MB
static size_t preferredWorkspaceSize = 134217728;
std::call_once(init_flag,
[&]()
{
char const* envWorkspaceSize = std::getenv("TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE");
size_t workspaceSize = 0;
if (envWorkspaceSize != nullptr)
{
workspaceSize = std::atoi(envWorkspaceSize);
}
preferredWorkspaceSize = std::max(preferredWorkspaceSize, workspaceSize);
});
return preferredWorkspaceSize;
}
class GemmAllreduceNvlsMemoryManager
{
public:
GemmAllreduceNvlsMemoryManager()
{
TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager constructor");
}
~GemmAllreduceNvlsMemoryManager()
{
TLLM_LOG_DEBUG("GemmAllreduceNvlsMemoryManager destructor");
}
std::pair<PersistentWorkspaceInterface*, tensorrt_llm::runtime::IpcNvlsHandle*> getWorkspace(
GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs const& problem,
AllocationKey const& key)
{
int M = std::get<0>(problem.problem_size);
int N = std::get<1>(problem.problem_size);
size_t requiredSize = M * N * 2;
size_t preferredWorkspaceSize = getPreferredWorkspaceSize();
if (requiredSize > preferredWorkspaceSize)
{
std::stringstream ss;
ss << "Please set TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE to at least " << requiredSize << " bytes";
TLLM_THROW("%s", ss.str().c_str());
}
auto handle = mHandles[key];
if (handle == nullptr)
{
TLLM_LOG_DEBUG("Creating allreduce workspace for %s", key.toString().c_str());
handle = std::make_shared<IpcNvlsHandleWrapper>(preferredWorkspaceSize, key.group);
GemmAllReduceImplInterface::ProblemArgs tmpArgs;
int maxN = 16384;
int maxM = preferredWorkspaceSize / (maxN * 2);
tmpArgs.argProblemShape(maxM, maxN, 512, 1)
.argRanks(problem.rank, problem.ranks)
.argLaunchConfig(runner->getSupportedLaunchConfigs()[0]);
auto workspace = runner->getPersistentWorkspace(tmpArgs);
workspace->allocate();
mWorkspaces[key] = workspace;
mHandles[key] = handle;
}
return std::make_pair(mWorkspaces[key].get(), mHandles[key]->getHandle());
}
private:
std::unordered_map<AllocationKey, std::shared_ptr<PersistentWorkspaceInterface>, AllocationKeyHash> mWorkspaces;
std::unordered_map<AllocationKey, std::shared_ptr<IpcNvlsHandleWrapper>, AllocationKeyHash> mHandles;
};
GemmAllreduceNvlsMemoryManager* getGemmAllreduceNvlsMemoryManager()
{
static GemmAllreduceNvlsMemoryManager gNvlsMemoryManager;
return &gNvlsMemoryManager;
}
at::Tensor runGemmImpl(GemmAllReduceImplInterface* runner, GemmAllReduceImplInterface::ProblemArgs& problem,
at::ScalarType outputDtype, c10::cuda::CUDAStream stream)
{
AllocationKey key{stream.device_index(), problem.ranks};
auto [workspace, handle] = getGemmAllreduceNvlsMemoryManager()->getWorkspace(runner, problem, key);
problem.argD((void*) handle->uc_ptr, (void*) handle->mc_ptr, (void**) handle->ipc_uc_ptrs.data());
problem.argWorkspace(workspace);
runner->run(problem, stream);
size_t dSize
= std::get<0>(problem.problem_size) * std::get<1>(problem.problem_size) * c10::elementSize(outputDtype);
auto D = at::detail::empty_cuda({std::get<0>(problem.problem_size), std::get<1>(problem.problem_size)}, outputDtype,
stream.device(), std::nullopt);
TLLM_CUDA_CHECK(cudaMemcpyAsync(
D.data_ptr(), reinterpret_cast<void const*>(handle->uc_ptr), dSize, cudaMemcpyDeviceToDevice, stream));
return D;
}
} // namespace
namespace torch_ext
{
class Fp4GemmAllreduceRunner : public torch::CustomClassHolder
{
public:
explicit Fp4GemmAllreduceRunner(at::ScalarType outputDtype, int64_t rank, torch::List<int64_t> group)
: mOutputDtype(outputDtype)
, mRank(rank)
{
for (int64_t rank : group)
{
mGroup.insert(static_cast<int>(rank));
}
if (outputDtype == at::ScalarType::Half)
{
using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t,
cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
}
else if (outputDtype == at::ScalarType::BFloat16)
{
using Traits = GemmTypes<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::bfloat16_t,
cutlass::bfloat16_t, cutlass::float_ue4m3_t, cutlass::float_ue4m3_t, cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::RowMajor>;
mRunner = std::make_shared<GemmAllReduceImplRunner<Traits>>();
}
else
{
TLLM_THROW("Unsupported output dtype: %s", torch::toString(outputDtype));
}
mConfigs = mRunner->getSupportedLaunchConfigs();
}
at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
at::Tensor const& mat2Scale, at::Tensor const& alpha, int64_t configIdx) const
{
if (configIdx < 0)
configIdx = 0;
TORCH_CHECK(configIdx < int64_t(mConfigs.size()), "configIdx out of bounds");
const int64_t M = mat1.size(0);
const int64_t N = mat2.size(0);
const int64_t K = mat1.size(1) * 2;
GemmAllReduceImplInterface::ProblemArgs problemArgs;
problemArgs.argProblemShape(M, N, K, 1);
problemArgs.argA(mat1.data_ptr());
problemArgs.argB(mat2.data_ptr());
problemArgs.argAScale(mat1Scale.data_ptr());
problemArgs.argBScale(mat2Scale.data_ptr());
problemArgs.argC(nullptr);
problemArgs.argAlphaPtr(reinterpret_cast<float const*>(alpha.const_data_ptr()));
problemArgs.argBeta(0.f);
problemArgs.argRanks(mRank, mGroup);
problemArgs.argLaunchConfig(mConfigs[configIdx]);
auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());
return runGemmImpl(mRunner.get(), problemArgs, mOutputDtype, stream);
}
int64_t getNumConfigs() const
{
return static_cast<int64_t>(mConfigs.size());
}
private:
at::ScalarType mOutputDtype;
int mRank;
std::set<int> mGroup;
std::shared_ptr<GemmAllReduceImplInterface> mRunner{nullptr};
std::vector<GemmAllReduceImplInterface::LaunchConfig> mConfigs;
};
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.class_<torch_ext::Fp4GemmAllreduceRunner>("Fp4GemmAllreduceRunner")
.def(torch::init<at::ScalarType, int64_t, torch::List<int64_t>>())
.def("run_gemm", &torch_ext::Fp4GemmAllreduceRunner::runGemm)
.def("get_num_configs", &torch_ext::Fp4GemmAllreduceRunner::getNumConfigs);
}
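The workspace manager above hands out one NVLS buffer per (device index, TP group) and requires the fp16/bf16 GEMM output, M x N at 2 bytes per element, to fit within TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE, which defaults to 128 MB and is read only once per process. A minimal sketch of that bound, with hypothetical shapes:

import os

# Hypothetical problem shape; 2 bytes per fp16/bf16 output element, as in getWorkspace().
M, N = 16384, 8192
required = M * N * 2        # 268435456 bytes (256 MiB)
default = 134217728         # built-in 128 MB default
if required > default:
    # Must be set before the first fused GEMM+allreduce call in this process,
    # because the C++ side caches the value with std::call_once.
    os.environ["TRTLLM_GEMM_ALLREDUCE_WORKSPACE_SIZE"] = str(required)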

View File

@@ -1869,3 +1869,99 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)
class Fp4GemmAllreduceRunner(TunableRunner):
runner_dict = dict()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(
2, 0, fp4_scale_infer_shape), ))
def __init__(
self,
output_dtype: torch.dtype,
tp_rank: int,
tp_group: List[int],
):
self.output_dtype = output_dtype
self.tp_rank = tp_rank
self.tp_group_str = '-'.join(str(g) for g in tp_group)
instance_key = (output_dtype, self.tp_group_str)
if instance_key not in Fp4GemmAllreduceRunner.runner_dict:
Fp4GemmAllreduceRunner.runner_dict[
instance_key] = torch.classes.trtllm.Fp4GemmAllreduceRunner(
output_dtype, tp_rank, tp_group)
self.fp4_gemm_all_reduce_runner = Fp4GemmAllreduceRunner.runner_dict[
instance_key]
def unique_id(self):
return (self.output_dtype, self.tp_group_str)
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return list(range(self.fp4_gemm_all_reduce_runner.get_num_configs()))
def forward(
self,
inputs: List[torch.Tensor],
tactic: int = 0,
) -> torch.Tensor:
mat1, mat2, mat1_scale, mat2_scale, global_scale = inputs
return self.fp4_gemm_all_reduce_runner.run_gemm(
mat1,
mat2,
mat1_scale,
mat2_scale,
global_scale,
tactic,
)
@torch.library.custom_op("trtllm::nvfp4_gemm_allreduce", mutates_args=())
def nvfp4_gemm_allreduce(
act_fp4: torch.Tensor,
weight_fp4: torch.Tensor,
act_sf: torch.Tensor,
weight_scale: torch.Tensor,
alpha: torch.Tensor,
output_dtype: torch.dtype,
tp_rank: int,
tp_group: List[int],
) -> torch.Tensor:
AutoTuner.get()
# Use Cutlass runner with predefined configs
nvfp4_gemm_allreduce_runner = Fp4GemmAllreduceRunner(
output_dtype, tp_rank, tp_group)
# TODO: Enable auto-tuning
# runner_type = type(nvfp4_gemm_allreduce_runner).__name__
# _, best_tactic = tuner.choose_one(
# f"trtllm::nvfp4_gemm_allreduce::{runner_type}",
# [nvfp4_gemm_allreduce_runner],
# nvfp4_gemm_allreduce_runner.tuning_config,
# [act_fp4, weight, act_sf, weight_scale, alpha],
# )
best_tactic = -1
return nvfp4_gemm_allreduce_runner(
inputs=[act_fp4, weight_fp4, act_sf, weight_scale, alpha],
tactic=best_tactic)
@nvfp4_gemm_allreduce.register_fake
def _(
act_fp4: torch.Tensor,
weight_fp4: torch.Tensor,
act_sf: torch.Tensor,
weight_scale: torch.Tensor,
alpha: torch.Tensor,
output_dtype: torch.dtype,
tp_rank: int,
tp_group: List[int],
) -> torch.Tensor:
return act_fp4.new_empty((act_fp4.size(0), weight_fp4.size(0)),
dtype=output_dtype)
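The op registered above is collective: it runs the GEMM on this rank's shard and an NVLS allreduce across tp_group, so every rank is expected to call it from its own process (as the multi-GPU unit test later in this commit does). A per-rank usage sketch with hypothetical shapes; quantization mirrors that test, and the block-scale layout is assumed to be the swizzled layout produced by fp4_quantize:

import torch

dtype, tp_rank, tp_group = torch.bfloat16, 0, [0, 1]      # one process per rank
x = torch.randn(256, 1024, dtype=dtype, device="cuda")    # this rank's K-shard of the activation
w = torch.randn(512, 1024, dtype=dtype, device="cuda")    # this rank's weight shard (N x K_shard)
x_sf_global = (448 * 6) / x.abs().max().float()
w_sf_global = (448 * 6) / w.abs().max().float()
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(x, x_sf_global, 16, False)
w_fp4, w_sf = torch.ops.trtllm.fp4_quantize(w, w_sf_global, 16, False)
alpha = 1.0 / (x_sf_global * w_sf_global)
out = torch.ops.trtllm.nvfp4_gemm_allreduce(act_fp4, w_fp4, act_sf, w_sf, alpha,
                                            dtype, tp_rank, tp_group)
# out has shape (256, 512) in dtype and is already summed across tp_group.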

View File

@@ -14,7 +14,8 @@ from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, MoEAllReduce)
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm._utils import get_sm_version, mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger
@@ -673,9 +674,51 @@ class LlamaDecoderLayer(DecoderLayer):
# Disable fusion for small models due to accuracy issues
self.enable_fusion &= config.hidden_size > 4096
self.PRE_MLP_FUSION = self.mapping.has_tp(
enable_gemm_allreduce_fusion = (os.environ.get(
"TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1") == "1")
mpi_enabled = not mpi_disabled()
dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
tp_valid = self.mapping.tp_size > 1
quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
device_supported = get_sm_version() >= 100
nvls_supported = ipc_nvls_supported()
use_fused_gemm_allreduce = all([
enable_gemm_allreduce_fusion, mpi_enabled, dtype_supported,
tp_valid, quant_valid, device_supported, nvls_supported
])
def check_in_out_features(in_features, out_features):
in_feature_valid = in_features % 128 == 0 and in_features >= 1024
out_feature_valid = out_features % 64 == 0 and out_features >= 1024
return all([in_feature_valid, out_feature_valid])
num_heads = config.num_attention_heads
head_dim = getattr(config, 'head_dim', None)
if not isinstance(head_dim, int):
head_dim = config.hidden_size // num_heads
in_features = num_heads * head_dim
out_features = config.hidden_size
in_out_features_valid = check_in_out_features(in_features, out_features)
attn_fused_gemm_allreduce = all(
[use_fused_gemm_allreduce, in_out_features_valid])
self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp(
) and not self.enable_attention_dp and self.enable_fusion
self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
in_features = config.intermediate_size
out_features = config.hidden_size
in_features_aligned_with_tp = in_features % self.mapping.tp_size == 0
in_out_features_valid = check_in_out_features(
in_features // self.mapping.tp_size, out_features)
mlp_fused_gemm_allreduce = all([
use_fused_gemm_allreduce, in_features_aligned_with_tp,
in_out_features_valid
])
self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp(
) and self.enable_fusion
if self.is_nvfp4:
self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
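A concrete reading of the gating above, as a self-contained sketch with hypothetical Llama-70B-like shapes; when a projection passes the size check, the corresponding PRE_MLP_FUSION or POST_MLP_FUSION allreduce fusion is turned off in favor of the fused GEMM+allreduce path:

# Hypothetical shapes, for illustration only.
hidden_size, num_heads, head_dim = 8192, 64, 128
intermediate_size, tp_size = 28672, 8

def check_in_out_features(in_features, out_features):
    in_ok = in_features % 128 == 0 and in_features >= 1024
    out_ok = out_features % 64 == 0 and out_features >= 1024
    return in_ok and out_ok

# Attention output projection: in = num_heads * head_dim, out = hidden_size.
print(check_in_out_features(num_heads * head_dim, hidden_size))           # True -> attn uses fused path
# MLP down projection: in = intermediate_size / tp_size, out = hidden_size.
print(intermediate_size % tp_size == 0 and
      check_in_out_features(intermediate_size // tp_size, hidden_size))   # True -> MLP uses fused path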

View File

@@ -14,7 +14,8 @@ from torch.nn.parameter import Parameter
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm._utils import is_device_integrated
from tensorrt_llm._utils import is_device_integrated, mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
from tensorrt_llm.logger import logger
@@ -307,6 +308,11 @@ class LinearMethodBase(ABC):
bias: Optional[torch.Tensor], *args, **kwargs):
raise NotImplementedError
def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor], tp_rank: int,
tp_group: List[int], *args, **kwargs):
raise NotImplementedError
def load_weights(self,
module: Linear,
weights: List[Dict],
@@ -908,8 +914,7 @@ class NVFP4LinearMethod(LinearMethodBase):
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
def _input_prepare(self, module: Linear, input: torch.Tensor):
if isinstance(input, Fp4QuantizedTensor):
# Input is already quantized - this should not happen if pre_quant_scale exists
# because we disable FP4 output for attention output when pre_quant_scale is present
@@ -935,7 +940,11 @@
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, module.input_scale, module.scaling_vector_size, False)
return act_fp4, act_sf
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
act_fp4, act_sf = self._input_prepare(module, input)
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
# Convert list to comma-separated string for torch.compile compatibility
allowed_backends_str = ','.join(module.nvfp4_allowed_backends)
@@ -956,6 +965,21 @@
output = output + bias
return output
def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor], tp_rank: int,
tp_group: List[int]):
act_fp4, act_sf = self._input_prepare(module, input)
output = torch.ops.trtllm.nvfp4_gemm_allreduce(
act_fp4, module.weight, act_sf, module.weight_scale, module.alpha,
module.dtype, tp_rank, tp_group)
# Take the dim of out_features if padded. Make sure the output is contiguous
if output.shape[-1] > module.out_features:
output = output[..., :module.out_features].contiguous()
if bias is not None:
output = output + bias
return output
def load_kv_scales(self, weights: List[Dict]):
k_scale, v_scale = [], []
for w in weights:
@@ -2133,6 +2157,23 @@
self.use_custom_cublas_mm = use_custom_cublas_mm
self.lora = lora
mpi_enabled = not mpi_disabled()
dtype_supported = self.dtype in (torch.float16, torch.bfloat16)
in_features_aligned = self.in_features % 128 == 0
out_features_aligned = self.out_features % 64 == 0
tp_valid = self.tp_mode is not None and self.tp_mode == TensorParallelMode.ROW and self.tp_size > 1
quant_valid = self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
)
device_supported = get_sm_version() >= 100
nvls_supported = ipc_nvls_supported()
self.use_fused_gemm_allreduce = all([
self.reduce_output, mpi_enabled, dtype_supported,
in_features_aligned, out_features_aligned, tp_valid, quant_valid,
device_supported, nvls_supported
])
self.enable_cuda_core = False
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability(
@@ -2224,13 +2265,20 @@
lora_params: Optional[dict] | None = None,
layer_idx: Optional[int] | None = None):
output = self.quant_method.apply(self, input, bias)
if self.lora is not None and bool(lora_params):
lora_result = self.lora(input, lora_params, layer_idx)
if lora_result is not None:
output = output + lora_result
return output
def apply_linear_allreduce(self,
input,
bias,
layer_idx: Optional[int] | None = None):
output = self.quant_method.apply_linear_allreduce(
self, input, bias, self.tp_rank, self.mapping.tp_group)
return output
def _maybe_fuse_bias_into_allreduce(
self,
bias: Optional[torch.Tensor],
@@ -2257,16 +2305,23 @@
layer_idx: Optional[int] = None,
) -> torch.Tensor:
if self.tp_mode == TensorParallelMode.ROW:
use_fused_gemm_allreduce = self.use_fused_gemm_allreduce and lora_params is None
if use_fused_gemm_allreduce and all_reduce_params is not None:
use_fused_gemm_allreduce = all_reduce_params.enable_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE
bias = None if (self.tp_rank > 0) else self.bias
if self.reduce_output:
fuse_bias = self._maybe_fuse_bias_into_allreduce(
bias, all_reduce_params)
bias = None if fuse_bias else bias
output = self.apply_linear(input, bias, lora_params, layer_idx)
output = self.all_reduce(
output,
all_reduce_params=all_reduce_params,
)
if use_fused_gemm_allreduce:
output = self.apply_linear_allreduce(
input, self.bias, layer_idx)
else:
fuse_bias = self._maybe_fuse_bias_into_allreduce(
bias, all_reduce_params)
bias = None if fuse_bias else bias
output = self.apply_linear(input, bias, lora_params,
layer_idx)
output = self.all_reduce(
output, all_reduce_params=all_reduce_params)
else:
output = self.apply_linear(input, bias, lora_params, layer_idx)
elif self.tp_mode == TensorParallelMode.COLUMN:
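Condensed, the ROW-parallel branch above takes the fused kernel only when the layer qualified at construction time, no LoRA adapter is active, and any caller-supplied AllReduceParams asks for a plain (fusion-free) allreduce; otherwise it falls back to apply_linear followed by all_reduce. A compact restatement of that predicate, as a reading aid rather than the implementation:

from tensorrt_llm.functional import AllReduceFusionOp

def picks_fused_gemm_allreduce(layer, lora_params, all_reduce_params):
    # Mirrors the ROW-parallel branch of Linear.forward above.
    use_fused = layer.use_fused_gemm_allreduce and lora_params is None
    if use_fused and all_reduce_params is not None:
        use_fused = (all_reduce_params.enable_allreduce
                     and all_reduce_params.fusion_op == AllReduceFusionOp.NONE)
    return layer.reduce_output and use_fused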

View File

@@ -7,11 +7,16 @@ import pytest
import torch
from mpi4py import MPI
from torch import nn
from utils.util import skip_pre_blackwell
import tensorrt_llm
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import pad_up
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@@ -161,7 +166,6 @@
)
l0.load_weights([dict(weight=weights[0])])
l0.cuda()
xs = torch.chunk(x, 2, dim=-1)
l0 = torch.compile(l0, fullgraph=True)
output = l0.forward(xs[tensor_parallel_rank])
@@ -333,3 +337,150 @@ def test_row_linear_norm_fusion(seq_len, hidden_size, mpi_pool_executor):
[l0_weight], hidden_size, dtype)] * 2))
for r in results:
assert r is True
def check_accuracy(a, b, atol, rtol, percent):
assert a.shape == b.shape
assert a.dtype == b.dtype
a = a.to(torch.float32)
b = b.to(torch.float32)
left = torch.abs(a - b)
right = atol + rtol * torch.abs(b)
count = torch.sum(left > right)
mismatch_percent = count / a.numel()
if not (mismatch_percent < 1 - percent):
raise Exception("Mismatch percentage is %f for rtol %f" %
(mismatch_percent, rtol))
@torch.inference_mode
def fp4_row_linear_allreduce(tp_size, local_rank, seq_len, output_size,
hidden_size, dtype, output_ref, x_sf_global,
w_sf_global, x_fp4s, w_fp4, x_sf_blocks,
w_sf_block_unswizzled):
output_ref = output_ref.cuda()
x_sf_global = x_sf_global.cuda()
w_sf_global = w_sf_global.cuda()
x_fp4 = x_fp4s[local_rank].cuda()
w_fp4 = w_fp4.cuda()
x_sf_block = x_sf_blocks[local_rank].cuda()
w_sf_block_unswizzled = w_sf_block_unswizzled.cuda()
qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
l0 = Linear(
in_features=hidden_size,
out_features=output_size,
bias=False,
dtype=dtype,
quant_config=qc,
mapping=Mapping(
world_size=tp_size,
tp_size=tp_size,
rank=local_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
)
l0.load_weights([{
'input_scale':
1.0 / x_sf_global.cpu(),
'weight':
w_fp4.cpu(),
'weight_scale':
w_sf_block_unswizzled.view(torch.float8_e4m3fn),
'weight_scale_2':
1.0 / w_sf_global.cpu()
}])
l0.cuda()
# TODO: parameters['weight']' size mismatch at index 0
# l0 = torch.compile(l0)
with torch.inference_mode(), autotune():
output = l0.forward((x_fp4, x_sf_block))
torch.cuda.synchronize()
check_accuracy(output, output_ref, atol=0.05, rtol=0.05, percent=0.99)
def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len,
output_size, hidden_size, dtype,
output_ref, x_sf_global,
w_sf_global, x_fp4s, w_fp4,
x_sf_blocks,
w_sf_block_unswizzled):
local_rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(local_rank)
try:
func(tp_size, local_rank, seq_len, output_size, hidden_size, dtype,
output_ref, x_sf_global, w_sf_global, x_fp4s, w_fp4, x_sf_blocks,
w_sf_block_unswizzled)
except Exception as e:
print(f"Error: {e}")
raise
return True
@skip_pre_blackwell
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("seq_len", [256, 400], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("output_size", [32, 64], ids=lambda x: f"output:{x}")
@pytest.mark.parametrize("hidden_size", [128, 256], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16],
ids=lambda x: f"dtype:{x}")
@pytest.mark.parametrize("mpi_pool_executor", [2],
indirect=True,
ids=lambda x: f"tp_size:{x}")
def test_fp4_row_linear_allreduce(seq_len, output_size, hidden_size, dtype,
mpi_pool_executor):
torch.manual_seed(42)
tp_size = mpi_pool_executor.num_workers
x = torch.randn((seq_len, hidden_size), dtype=dtype).cuda()
w = torch.randn((output_size, hidden_size), dtype=dtype).cuda()
scaling_vector_size = 16
x_sf_global = (448 * 6) / x.abs().max().float()
x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global,
scaling_vector_size,
False)
w_sf_global = (448 * 6) / w.abs().max().float()
w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
scaling_vector_size,
False)
w_sf_block_unswizzled = (torch.ops.trtllm.block_scale_interleave_reverse(
w_sf_block.cpu().view(pad_up(output_size, 128), -1)))
with torch.inference_mode():
alpha_ref = 1.0 / (w_sf_global * x_sf_global)
output_ref = torch.ops.trtllm.fp4_gemm(
x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_ref,
fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, dtype)
torch.cuda.synchronize()
xs = [x.contiguous().cuda() for x in torch.chunk(x, tp_size, dim=-1)]
x_fp4s = []
x_sf_blocks = []
for i in range(tp_size):
_fp4, _sf_block = torch.ops.trtllm.fp4_quantize(xs[i], x_sf_global,
scaling_vector_size,
False)
x_fp4s.append(_fp4.cpu())
x_sf_blocks.append(_sf_block.cpu())
output_ref = output_ref.cpu()
x_sf_global = x_sf_global.cpu()
w_sf_global = w_sf_global.cpu()
w_fp4 = w_fp4.cpu()
w_sf_block_unswizzled = w_sf_block_unswizzled.cpu()
results = mpi_pool_executor.map(
fp4_row_linear_allreduce_run_single_rank,
*zip(*[(fp4_row_linear_allreduce, tp_size, seq_len, output_size,
hidden_size, dtype, output_ref, x_sf_global, w_sf_global,
x_fp4s, w_fp4, x_sf_blocks, w_sf_block_unswizzled)] * tp_size))
for r in results:
assert r is True
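For reference, check_accuracy above tolerates a small fraction of outliers instead of requiring every element to match: with percent=0.99, fewer than 1% of elements may violate |a - b| <= atol + rtol * |b|. A tiny self-contained illustration with the tolerances used in this test:

import torch

a = torch.zeros(1000)
b = torch.zeros(1000)
a[:5] = 1.0                                                   # 0.5% of elements far off
mismatch = ((a - b).abs() > 0.05 + 0.05 * b.abs()).float().mean()
print(mismatch.item())                                        # 0.005 -> within the 1% budget, passes
a[:20] = 1.0                                                  # 2% of elements far off
mismatch = ((a - b).abs() > 0.05 + 0.05 * b.abs()).float().mean()
print(mismatch.item())                                        # 0.02 -> check_accuracy would raise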