mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
feat: Fallback to NCCL for various patterns when input size is large. (#4080)
* Fallback to NCCL for various patterns when input size is large. Move the previous implementation to cpp side.
  Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
* Revising.
  Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
---------
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent
7d94c9561f
commit
bb7bcc75c2
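Before the diff, a minimal Python sketch of the behavior this commit moves into the C++ allreduce op. It is illustrative only: the helper name and string labels are not from the repository; the rule itself (fall back to NCCL when the message exceeds the custom-all-reduce workspace or P2P/NVLink is unavailable, then re-create the fused epilogue with separate kernels) is taken from the diff below.

def pick_allreduce_path(message_size_bytes, max_workspace_size, fusion_op, p2p_ok, nvlink_ok):
    # Hypothetical helper for illustration; mirrors ifFallbackToNCCL() + selectImplementation() below.
    if message_size_bytes > max_workspace_size or not p2p_ok or not nvlink_ok:
        # NCCL all-reduce, then residual add + RMS norm + optional quantization as follow-up kernels.
        return "NCCL + fallbackRunSubsequentOps"
    if fusion_op in ("NONE", "RESIDUAL_RMS_NORM"):
        return "NCCL or MIN_LATENCY, chosen by message size and world size"
    return "MIN_LATENCY fused kernel"

print(pick_allreduce_path(64 * 1024 * 1024, 16 * 1024 * 1024, "RESIDUAL_RMS_NORM_QUANT_FP8", True, True))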
@@ -27,9 +27,11 @@
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/fp4Quantize.h"
#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/userbuffersTensor.h"

#include "userbuffersTensor.h"
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <nccl.h>
@@ -160,21 +162,21 @@ public:

// If strategy is set to UB, UB must be used as UB impl output is special and cannot be used
// by others.
AllReduceStrategyType runtimeStrategy = getRuntimeStrategy(seq_len, size);
AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);

// Log runtime strategy
auto const rank = COMM_SESSION.getRank();
logRunTimeStrategy(runtimeStrategy, rank);
logRunTimeStrategy(runtime_strategy, rank);

// Dispatch to different allreduce implementations
switch (runtimeStrategy)
switch (runtime_strategy)
{
case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
case AllReduceStrategyType::MIN_LATENCY:
case AllReduceStrategyType::ONESHOT:
case AllReduceStrategyType::TWOSHOT:
return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtimeStrategy);
return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtime_strategy);
default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
}
}
@@ -280,37 +282,18 @@ private:

auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
int hidden_size = input.size(-1);

if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
torch::Tensor reduce_output = torch::empty_like(input);
NCCLCHECK(ncclAllReduce(input.data_ptr(), reduce_output.mutable_data_ptr(), size, (*getDtypeMap())[mType],
ncclSum, *mNcclComm, stream));

if (mOp == AllReduceFusionOp::NONE)
{
torch::Tensor norm_out = torch::empty_like(input);
torch::Tensor residual_out = torch::empty_like(input);

NCCLCHECK(ncclAllReduce(input.data_ptr(), residual_out.mutable_data_ptr(), size, (*getDtypeMap())[mType],
ncclSum, *mNcclComm, stream));
tensorrt_llm::kernels::AllReduceParams params;
params.fusion_params.bias_buffer = bias ? bias.value().data_ptr() : nullptr;
params.fusion_params.residual_buffer = residual ? residual.value().data_ptr() : nullptr;
params.fusion_params.weight_buffer = norm_weight ? norm_weight.value().data_ptr() : nullptr;
params.local_output_buffer_ptr = norm_out.mutable_data_ptr();
params.elts_total = size;

params.fusion_params.hidden_size = hidden_size;
params.fusion_params.eps = mEps;
params.fusion_params.intermediate_buffer = residual_out.mutable_data_ptr();
tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, mOp);
return {norm_out, residual_out};
return {reduce_output};
}
else if (mOp == AllReduceFusionOp::NONE)
{
torch::Tensor output = torch::empty_like(input);
NCCLCHECK(ncclAllReduce(input.data_ptr(), output.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum,
*mNcclComm, stream));
return {output};
}
TORCH_CHECK(false, "NCCL encounters unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
return {};

// Treat any other patterns as fallback cases.
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
}

std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
@@ -413,22 +396,22 @@ private:
|| mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
{
// TODO: Better check for each pattern
int64_t sfVecSize = 16;
int64_t sf_vec_size = 16;
int64_t m = 1;
auto const& inputShape = input.sizes();
auto const& r = inputShape.size();
auto const& input_shape = input.sizes();
auto const& r = input_shape.size();
TORCH_CHECK(r >= 2, "Input should be >=2D tensor.");
for (size_t i = 0; i < r - 1; i++)
{
m *= inputShape[i];
m *= input_shape[i];
}
auto const k = inputShape[r - 1];
TORCH_CHECK(k % sfVecSize == 0, "Input should be divisible by sfVecSize.");
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[r - 1] = k / 2;
auto const k = input_shape[r - 1];
TORCH_CHECK(k % sf_vec_size == 0, "Input should be divisible by sfVecSize.");
std::vector<int64_t> output_shape(input_shape.begin(), input_shape.end());
output_shape[r - 1] = k / 2;

quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt);
scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)},
quant_out = at::detail::empty_cuda(output_shape, FLOAT4_E2M1X2, input.device(), std::nullopt);
scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sf_vec_size)},
SF_DTYPE, input.device(), std::nullopt);
residual_out = torch::empty_like(residual.value());

@@ -495,18 +478,86 @@ private:
return {};
}

std::vector<torch::Tensor> fallbackRunSubsequentOps(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
torch::Tensor& reduce_output) noexcept
{
// If we reach here, it means the extra fallback operations are required.
// All patterns are broken into AllReduce + residual_rms_norm + following operations (quantization, etc.)
auto const size = input.numel();
auto const hidden_size = input.size(-1);
auto const stream = at::cuda::getCurrentCUDAStream(input.get_device());

torch::Tensor norm_out = torch::empty_like(input);

tensorrt_llm::kernels::AllReduceParams params;
params.fusion_params.bias_buffer = bias ? bias.value().data_ptr() : nullptr;
params.fusion_params.residual_buffer = residual ? residual.value().data_ptr() : nullptr;
params.fusion_params.weight_buffer = norm_weight ? norm_weight.value().data_ptr() : nullptr;
params.local_output_buffer_ptr = norm_out.mutable_data_ptr();
params.elts_total = size;

params.fusion_params.hidden_size = hidden_size;
params.fusion_params.eps = mEps;
params.fusion_params.intermediate_buffer = reduce_output.mutable_data_ptr();
tensorrt_llm::kernels::residualRmsNorm(params, mType, stream, AllReduceFusionOp::RESIDUAL_RMS_NORM);

// If no quantization is needed, return the norm and residual outputs.
if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
return {norm_out, reduce_output};
}

const int64_t sf_vecsize = 16;
bool const sf_use_ue8m0 = false;
bool const is_sf_swizzled_layout = true;
TORCH_CHECK(scale, "scale is required for quantization ops");

// Attach the subsequent operations after the residual RMS norm all-reduce and return the final outputs.
switch (mOp)
{
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
{
auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value());
return {quant_out, reduce_output};
}
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
{
auto [quant_out, scale_out]
= torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout);
return {quant_out, scale_out, reduce_output};
}
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
{
auto [quant_out, scale_out] = torch_ext::symmetric_static_quantize_per_tensor(norm_out, scale.value());
return {norm_out, quant_out, reduce_output};
}
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
{
auto [quant_out, scale_out]
= torch_ext::fp4_quantize(norm_out, scale.value(), sf_vecsize, sf_use_ue8m0, is_sf_swizzled_layout);
return {norm_out, quant_out, scale_out, reduce_output};
}
default: break;
}

TORCH_CHECK(false, "Unsupported fusion operation: " + tensorrt_llm::kernels::toString(mOp));
return {};
}

AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size) noexcept
{
static char* forceNcclAllReduceStrategyChar = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY");
bool forceNcclAllReduceStrategy = (forceNcclAllReduceStrategyChar != nullptr);
AllReduceStrategyType runtimeStrategy;
static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY");
bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr);
AllReduceStrategyType runtime_strategy;
if (mStrategy == AllReduceStrategyType::UB)
{
runtimeStrategy = AllReduceStrategyType::UB;
runtime_strategy = AllReduceStrategyType::UB;
}
else if (forceNcclAllReduceStrategy || mStrategy == AllReduceStrategyType::NCCL)
else if (force_nccl_all_reduce_strategy || mStrategy == AllReduceStrategyType::NCCL)
{
runtimeStrategy = AllReduceStrategyType::NCCL;
runtime_strategy = AllReduceStrategyType::NCCL;
}
else
{
@@ -514,14 +565,14 @@ private:
static char* ifForBenchMark = std::getenv("OVERRIDE_HEURISTIC_ALLREDUCE_STRATEGY");
if (ifForBenchMark != nullptr)
{
runtimeStrategy = mStrategy;
runtime_strategy = mStrategy;
}
else
{
runtimeStrategy = selectImplementation(seq_len, size, mGroup.size(), mType);
runtime_strategy = selectImplementation(seq_len, size, mGroup.size(), mType);
}
}
return runtimeStrategy;
return runtime_strategy;
}

void logRunTimeStrategy(AllReduceStrategyType strategy, int rank) noexcept
@@ -557,9 +608,9 @@ private:
static std::map<std::set<int>, std::tuple<bool, bool>> cache;
if (cache.find(mGroup) != cache.end())
{
auto [isNVLINKSupported, isP2PSupported] = cache[mGroup];
mIsNVLINKSupported = isNVLINKSupported;
mIsP2PSupported = isP2PSupported;
auto [is_NVLINK_supported, is_P2P_supported] = cache[mGroup];
mIsNVLINKSupported = is_NVLINK_supported;
mIsP2PSupported = is_P2P_supported;
return;
}
setGroupTopology();
@@ -570,8 +621,8 @@ private:
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_INFO("Detecting local TP group for rank %d", rank);
std::set<int> localGroup = getLocalGroup(mGroup);
if (mGroup.size() != localGroup.size())
std::set<int> local_group = getLocalGroup(mGroup);
if (mGroup.size() != local_group.size())
{
mIsP2PSupported = false;
mIsNVLINKSupported = false;
@@ -580,26 +631,27 @@ private:
}
TLLM_LOG_INFO("TP group is intra-node for rank %d", rank);

NvmlManager nvmlManager;
std::unordered_set<int> visitedDevice;
NvmlManager nvml_manager;
std::unordered_set<int> visited_device;
mIsP2PSupported = true;
mIsNVLINKSupported = true;

// Use cudaDeviceCanAccessPeer to determine whether p2p is supported,
// and use nvml to determine whether there are nvlink links between ranks.
for (int firstDeviceId : localGroup)
for (int first_device_id : local_group)
{
for (int secondDeviceId : localGroup)
for (int second_device_id : local_group)
{
if (firstDeviceId == secondDeviceId || visitedDevice.find(secondDeviceId) != visitedDevice.end())
if (first_device_id == second_device_id
|| visited_device.find(second_device_id) != visited_device.end())
{
continue;
}

int canAccessPeer = 0;
TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, firstDeviceId, secondDeviceId));
int can_access_peer = 0;
TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id));

if (!canAccessPeer)
if (!can_access_peer)
{
mIsP2PSupported = false;
mIsNVLINKSupported = false;
@@ -607,31 +659,31 @@ private:
return;
}

nvmlDevice_t firstDevice;
NVML_CHECK(nvmlDeviceGetHandleByIndex(firstDeviceId, &firstDevice));
nvmlDevice_t first_device;
NVML_CHECK(nvmlDeviceGetHandleByIndex(first_device_id, &first_device));

bool isNVLINK = false;
bool is_NVLINK = false;

for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++)
{
nvmlPciInfo_t remotePciInfo;
if (nvmlDeviceGetNvLinkRemotePciInfo_v2(firstDevice, link, &remotePciInfo) != NVML_SUCCESS)
nvmlPciInfo_t remote_pci_info;
if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS)
{
continue;
}

nvmlDevice_t remoteDevice;
auto const result = nvmlDeviceGetHandleByPciBusId_v2(remotePciInfo.busId, &remoteDevice);
nvmlDevice_t remote_device;
auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device);

if (result == NVML_SUCCESS)
{
// Two GPUs are connected directly through nvlink
unsigned int remoteDeviceId;
NVML_CHECK(nvmlDeviceGetIndex(remoteDevice, &remoteDeviceId));
unsigned int remote_device_id;
NVML_CHECK(nvmlDeviceGetIndex(remote_device, &remote_device_id));

if (remoteDeviceId == static_cast<unsigned int>(secondDeviceId))
if (remote_device_id == static_cast<unsigned int>(second_device_id))
{
isNVLINK = true;
is_NVLINK = true;
}
}
else if (result == NVML_ERROR_NOT_FOUND)
@@ -640,21 +692,21 @@ private:
// now remotePciInfo represents the pci information of nvswitch,
// determine whether nvlink is supported by whether two GPUs are connected to the same
// nvswitch.
nvmlDevice_t secondDevice;
NVML_CHECK(nvmlDeviceGetHandleByIndex(secondDeviceId, &secondDevice));
nvmlDevice_t second_device;
NVML_CHECK(nvmlDeviceGetHandleByIndex(second_device_id, &second_device));

for (unsigned int secondLink = 0; secondLink < NVML_NVLINK_MAX_LINKS; secondLink++)
for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++)
{
nvmlPciInfo_t secondRemotePciInfo;
if (nvmlDeviceGetNvLinkRemotePciInfo_v2(secondDevice, secondLink, &secondRemotePciInfo)
nvmlPciInfo_t second_remote_pci_info;
if (nvmlDeviceGetNvLinkRemotePciInfo_v2(second_device, second_link, &second_remote_pci_info)
!= NVML_SUCCESS)
{
continue;
}

if (strcmp(remotePciInfo.busId, secondRemotePciInfo.busId) == 0)
if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0)
{
isNVLINK = true;
is_NVLINK = true;
break;
}
}
@@ -664,28 +716,74 @@ private:
NVML_CHECK(result);
}

if (isNVLINK)
if (is_NVLINK)
{
break;
}
}

mIsNVLINKSupported &= isNVLINK;
mIsNVLINKSupported &= is_NVLINK;
}
visitedDevice.insert(firstDeviceId);
visited_device.insert(first_device_id);
}
}

bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto) noexcept
{
// If messageSize is greater than maxWorkspaceSize, fallback to NCCL, regardless of the fusion type.
if (message_size_bytes > max_workspace_size)
{
if (!is_auto)
{
TLLM_LOG_WARNING(
"Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL");
}
return true;
}

// If Peer to Peer is not supported, fallback to NCCL.
if (!mIsP2PSupported)
{
if (!is_auto)
{
TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL");
}
return true;
}

// If NVLINK is not supported, fallback to NCCL.
if (!mIsNVLINKSupported)
{
if (!is_auto)
{
TLLM_LOG_WARNING("Since NVLINK not supported, fallback to AllReduceStrategy: NCCL");
}
return true;
}
return false;
}

AllReduceStrategyType selectImplementation(
size_t seq_len, size_t messageSize, int worldSize, nvinfer1::DataType type) noexcept
size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept
{
// Check that heuristic is only applied when AUTO is set.
bool const isAuto = (mStrategy == AllReduceStrategyType::AUTO);
bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
auto const max_workspace_size
= tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(world_size);

// This rule based heuristic only chooses NCCL and MIN_LATENCY strategies.
if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size, is_auto))
{
return AllReduceStrategyType::NCCL;
}

// Only the intersection of the supported fusion types of two implementations will go through the heuristic.
// Otherwise, MIN_LATENCY strategy will be returned due to more fusion patterns it can support.
// This rule based heuristic only chooses between NCCL and MIN_LATENCY strategies.

// Heuristic will only be applied on NONE and RESIDUAL_RMS_NORM fusion types.
// Because NCCL might be faster on some large messageSize cases.
// Otherwise, MIN_LATENCY strategy will be directly returned due to more fusions it can support.
// TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
// This should be compared with MIN_LATENCY fused kernels to determine the best strategy.
switch (mOp)
{
case AllReduceFusionOp::NONE:
@@ -693,89 +791,60 @@ private:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
default: return AllReduceStrategyType::MIN_LATENCY;
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY;
// Suppose NCCL has fallback implementations for all fusion types.
default: return AllReduceStrategyType::NCCL;
}

// Check mOp to be supported by the heuristic.
TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM,
"Only NONE and RESIDUAL_RMS_NORM are supported for heuristic.");
"Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic.");

// If AUTO is set, but P2P is not supported, fallback to NCCL.
if (!mIsP2PSupported)
// Default to NCCL.
AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;

// Currently we will not remove ONESHOT and TWOSHOT from the strategy list,
// but torch flow users should not use them; use AUTO or MIN_LATENCY instead.
// NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO,
// user should guarantee the correctness of the fusion pattern dispatching.
if (!is_auto)
{
if (!isAuto)
if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT)
{
TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL");
}
return AllReduceStrategyType::NCCL;
}

// If AUTO is set, but NVLINK is not supported, fallback to NCCL.
if (isAuto && !mIsNVLINKSupported)
{
return AllReduceStrategyType::NCCL;
}

auto const maxWorkspaceSize = tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize);

AllReduceStrategyType strat = AllReduceStrategyType::NCCL;
auto const messageSizeBytes = messageSize * tensorrt_llm::common::getDTypeSize(type);

if (messageSizeBytes <= maxWorkspaceSize)
{
// Currently we will not remove ONESHOT and TWOSHOT from the strategy list,
// but torch flow users should not use them; use AUTO or MIN_LATENCY instead.
// NOTICE: When a fusion type is not supported by the corresponding strategy but strategy is not AUTO,
// user should guarantee the correctness of the fusion pattern dispatching.
if (!isAuto)
{
if (mStrategy == AllReduceStrategyType::ONESHOT || mStrategy == AllReduceStrategyType::TWOSHOT)
{
strat = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strat = mStrategy;
}
}
else if (worldSize <= 2)
{
strat = AllReduceStrategyType::MIN_LATENCY;
}
else if (worldSize <= 4)
{
if (messageSizeBytes < 1 * 1000 * 1000)
{
strat = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strat = AllReduceStrategyType::NCCL;
}
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else
{
if (messageSizeBytes < 500 * 1000)
{
strat = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strat = AllReduceStrategyType::NCCL;
}
strategy = mStrategy;
}
}
else if (world_size <= 2)
{
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else if (world_size <= 4)
{
if (message_size_bytes < 1 * 1000 * 1000)
{
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strategy = AllReduceStrategyType::NCCL;
}
}
else
{
if (!isAuto)
if (message_size_bytes < 500 * 1000)
{
TLLM_LOG_WARNING("Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL");
strategy = AllReduceStrategyType::MIN_LATENCY;
}
else
{
strategy = AllReduceStrategyType::NCCL;
}
strat = AllReduceStrategyType::NCCL;
}

return strat;
return strategy;
}

private:
@@ -793,10 +862,10 @@ private:

#endif // ENABLE_MULTI_DEVICE

std::vector<torch::Tensor> allreduce(torch::Tensor input, torch::optional<torch::Tensor> residual,
torch::optional<torch::Tensor> norm_weight, torch::optional<torch::Tensor> scale,
torch::optional<torch::Tensor> bias, torch::optional<torch::Tensor> workspace, torch::List<int64_t> group_,
int64_t const strategy_, int64_t const fusion_op_, double const eps_)
std::vector<torch::Tensor> allreduce(torch::Tensor const& input, torch::optional<torch::Tensor> const& residual,
torch::optional<torch::Tensor> const& norm_weight, torch::optional<torch::Tensor> const& scale,
torch::optional<torch::Tensor> const& bias, torch::optional<torch::Tensor> const& workspace,
torch::List<int64_t> const& group_, int64_t const strategy_, int64_t const fusion_op_, double const eps_)
{
#if ENABLE_MULTI_DEVICE
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());

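For reference, a small Python sketch that distills the NCCL vs MIN_LATENCY choice made by selectImplementation above once the fallback checks have passed. The thresholds (1 MB at world size <= 4, 500 KB beyond) come from the diff; the helper name is illustrative.

def select_strategy(message_size_bytes: int, world_size: int) -> str:
    # Illustrative distillation of the AUTO heuristic in selectImplementation().
    if world_size <= 2:
        return "MIN_LATENCY"
    if world_size <= 4:
        return "MIN_LATENCY" if message_size_bytes < 1 * 1000 * 1000 else "NCCL"
    return "MIN_LATENCY" if message_size_bytes < 500 * 1000 else "NCCL"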
@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "tensorrt_llm/thop/fp4Quantize.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

cpp/tensorrt_llm/thop/fp4Quantize.h (new file, 28 lines)
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>
#include <cstdint>

namespace torch_ext
{
std::tuple<torch::Tensor, torch::Tensor> fp4_quantize(torch::Tensor const& self, torch::Tensor const& globalScale,
int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout);
}
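The declaration above exposes the existing FP4 quantization entry point to the allreduce fallback path. On the Python side the same kernel is reachable as torch.ops.trtllm.fp4_quantize (the removed Python fallback further down calls it). A hedged usage sketch, assuming a CUDA build of tensorrt_llm with its torch ops registered; tensor shapes and the scale value are illustrative:

import torch
import tensorrt_llm  # assumed to register the torch.ops.trtllm.* custom ops

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
global_scale = torch.ones(1, dtype=torch.float32, device="cuda")  # per-tensor scale, illustrative value
# Arguments follow the removed Python fallback: (input, global_scale, sfVecSize=16, sfUseUE8M0=False).
quant_fp4, scale_factors = torch.ops.trtllm.fp4_quantize(x, global_scale, 16, False)
print(quant_fp4.shape, scale_factors.shape)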
@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "tensorrt_llm/thop/fp8Op.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/thop/thUtils.h"

cpp/tensorrt_llm/thop/fp8Op.h (new file, 44 lines)
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/EmptyTensor.h>

#include <cuda_fp16.h>

#include <cstdint>

namespace torch_ext
{
std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_weight(torch::Tensor weight);
std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_activation(torch::Tensor activation);
std::tuple<torch::Tensor, torch::Tensor> symmetric_quantize_per_tensor(torch::Tensor input);
std::tuple<torch::Tensor, torch::Tensor> symmetric_static_quantize_weight(torch::Tensor weight, torch::Tensor scales);
std::tuple<torch::Tensor, torch::Tensor> symmetric_static_quantize_activation(
torch::Tensor activation, torch::Tensor scales);
std::tuple<torch::Tensor, torch::Tensor> symmetric_static_quantize_per_tensor(
torch::Tensor input, torch::Tensor scales);

torch::Tensor symmetric_dequantize_weight(torch::Tensor weight, torch::Tensor scales);
torch::Tensor symmetric_dequantize_activation(torch::Tensor activation, torch::Tensor scales);
torch::Tensor symmetric_dequantize_per_tensor(torch::Tensor input, torch::Tensor scales);

} // namespace torch_ext
@@ -142,67 +142,12 @@ class AllReduce(nn.Module):
self.mapping = mapping
self.workspace = None
self.strategy = strategy
self.max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto(
self.mapping.tp_size, support_deterministic=False)

self.fallback_func_mapping = {
AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8:
self.fallback_residual_rms_norm_quant_fp8,
AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
self.fallback_residual_rms_norm_out_quant_fp8,
AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
self.fallback_residual_rms_norm_quant_nvfp4,
AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4:
self.fallback_residual_rms_norm_out_quant_nvfp4,
}

if self.mapping.tp_size > 1:
# When Strategy is UB, it is guaranteed that the workspace is not used.
if self.strategy != AllReduceStrategy.UB:
self.workspace = get_allreduce_workspace(self.mapping)

@staticmethod
def fallback_residual_rms_norm_quant_fp8(
output: Tuple[torch.Tensor, ...],
all_reduce_params: AllReduceParams,
):
norm_out, residual_out = output
quant_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
norm_out, all_reduce_params.scale)
return quant_fp8, residual_out

@staticmethod
def fallback_residual_rms_norm_out_quant_fp8(
output: Tuple[torch.Tensor, ...],
all_reduce_params: AllReduceParams,
):
norm_out, residual_out = output
quant_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
norm_out, all_reduce_params.scale)
return norm_out, quant_fp8, residual_out

@staticmethod
def fallback_residual_rms_norm_quant_nvfp4(
output: Tuple[torch.Tensor, ...],
all_reduce_params: AllReduceParams,
):
norm_out, residual_out = output
quant_fp4, scale_factor = torch.ops.trtllm.fp4_quantize(
norm_out, all_reduce_params.scale, 16, False)

return quant_fp4, scale_factor, residual_out

@staticmethod
def fallback_residual_rms_norm_out_quant_nvfp4(
output: Tuple[torch.Tensor, ...],
all_reduce_params: AllReduceParams,
):
norm_out, residual_out = output
quant_fp4, scale_factor = torch.ops.trtllm.fp4_quantize(
norm_out, all_reduce_params.scale, 16, False)

return norm_out, quant_fp4, scale_factor, residual_out

def forward(
self,
input: torch.Tensor,
@@ -241,16 +186,6 @@ class AllReduce(nn.Module):
if all_reduce_params is None:
all_reduce_params = AllReduceParams()

strategy = self.strategy
fusion_op = all_reduce_params.fusion_op

# If the input size is larger than the max workspace size, fallback to NCCL strategy
if input.numel() > self.max_workspace_size \
and all_reduce_params.fusion_op != AllReduceFusionOp.NONE \
and all_reduce_params.fusion_op != AllReduceFusionOp.RESIDUAL_RMS_NORM:
strategy = AllReduceStrategy.NCCL
fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM

output = torch.ops.trtllm.allreduce(
input=input,
residual=all_reduce_params.residual,
@@ -259,17 +194,11 @@ class AllReduce(nn.Module):
bias=all_reduce_params.bias,
workspace=self.workspace,
group=self.mapping.tp_group,
strategy=strategy,
op=fusion_op,
strategy=self.strategy,
op=all_reduce_params.fusion_op,
eps=all_reduce_params.eps,
)

if input.numel() > self.max_workspace_size \
and all_reduce_params.fusion_op != AllReduceFusionOp.NONE \
and all_reduce_params.fusion_op != AllReduceFusionOp.RESIDUAL_RMS_NORM:
output = self.fallback_func_mapping[all_reduce_params.fusion_op](
output, all_reduce_params)

return output if len(output) > 1 else output[0]
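With the fallback moved into the C++ op, the Python module now simply forwards strategy and fusion_op. A hedged usage sketch; the import paths, constructor signatures, and keyword arguments are inferred from the names visible in this diff, not copied from the repository:

import torch
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, AllReduceStrategy)
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=4, tp_size=4, rank=0)  # assumed constructor arguments
allreduce = AllReduce(mapping, strategy=AllReduceStrategy.AUTO)  # assumed signature

hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(hidden)
weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

# Oversized fused-quant requests are now handled inside torch.ops.trtllm.allreduce
# (NCCL all-reduce + residual RMS norm + quantization) rather than by Python fallbacks.
norm_out, residual_out = allreduce(
    hidden,
    all_reduce_params=AllReduceParams(fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                                      residual=residual, norm_weight=weight, eps=1e-5))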