Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10296][fix] Fix the potential misaligned access due to vectorized ld/st instructions in NVLinkOneSided A2A. (#10539)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
parent dbb858ae0c
commit f3a985ce27
@@ -33,6 +33,8 @@ namespace torch_ext
 namespace moe_comm
 {
 
+static constexpr size_t CACHELINE_ALIGNMENT = 128;
+
 // TODO: Is Alignment necessary?
 // Helper function to align offset to specified byte boundary
 inline size_t alignOffset(size_t offset, size_t alignment)
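
The hunk above cuts off before the body of alignOffset. For reference, a minimal sketch of the usual round-up idiom such a helper implements (the actual body is not shown in this diff, so treat this as an assumption):

// Sketch only: round offset up to the next multiple of alignment (alignment > 0).
inline size_t alignOffset(size_t offset, size_t alignment)
{
    return (offset + alignment - 1) / alignment * alignment;
}

For example, alignOffset(42, 128) returns 128, the next cacheline boundary.
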
@@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
     // TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
     // read.
     constexpr size_t SIZEOF_INT32 = 4;
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
 
     MoeA2ADataOffsets offsets;
     size_t offset = 0;
@@ -203,12 +204,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
     }
 
-    // Calculate buffer sizes for all payloads
-    // Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
-    int64_t totalBytesNeeded = 0;
-    std::vector<int64_t> payloadByteSizes;
+    // Record the cacheline aligned start offset for each payload's recv buffer.
+    // 1. We assume the base workspace ptr of each rank is aligned (checked in this OP)
+    // 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets)
+    // 3. We align the currentOffset during update.
+    // In this way, it is guaranteed that the recv buffer is (over-)aligned, sufficient for 128bit vectorized ld/st.
+
     std::vector<int> payloadElementSizes;
     std::vector<int> payloadElementsPerToken;
+    std::vector<size_t> payloadRecvBufferOffsets;
+
+    // Start offset for the first payload
+    size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
     for (auto const& payload : inputPayloads)
     {
         CHECK_CONTIGUOUS(payload);
@@ -216,16 +223,24 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
         TORCH_CHECK(
             payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
+        // Unlike recv buffer for payloads, payload itself is not allocated by us and we cannot control its alignment.
+        // We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
+        // dynamically determined based on bytes per token of this payload.
+        TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
 
         int elementsPerToken = static_cast<int>(payload.size(1));
         int elementSize = static_cast<int>(payload.dtype().itemsize());
         // Each payload buffer stores data from ALL ranks
         int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;
 
-        payloadByteSizes.push_back(bytesPerPayload);
         payloadElementSizes.push_back(elementSize);
         payloadElementsPerToken.push_back(elementsPerToken);
-        totalBytesNeeded += bytesPerPayload;
+
+        payloadRecvBufferOffsets.push_back(currentOffset);
+
+        // Update offset and align to cacheline boundary for the next payload recv buffer.
+        currentOffset += bytesPerPayload;
+        currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
     }
 
     CHECK_TH_CUDA(workspace);
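
To make the effect of the loop above concrete, here is a standalone toy sketch (not the op itself; the sizes are deliberately tiny and hypothetical) showing the recv-buffer offsets it produces:

#include <cstddef>
#include <cstdio>
#include <vector>

static constexpr size_t CACHELINE_ALIGNMENT = 128;

// Same round-up idiom as sketched for alignOffset above.
inline size_t alignOffset(size_t offset, size_t alignment)
{
    return (offset + alignment - 1) / alignment * alignment;
}

int main()
{
    size_t const epSize = 2, maxTokensPerRank = 3; // toy EP configuration
    std::vector<size_t> bytesPerToken = {7, 4};    // two hypothetical payloads
    size_t currentOffset = 0;                      // stands in for offsets[PAYLOAD_DATA_OFFSET_INDEX]
    for (size_t b : bytesPerToken)
    {
        std::printf("recv buffer starts at %zu\n", currentOffset); // prints 0, then 128
        currentOffset += epSize * maxTokensPerRank * b;            // 42, then 152
        currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
    }
    std::printf("combine region starts at %zu\n", currentOffset);  // prints 256
}

Without the per-payload alignOffset call, the second recv buffer would start at byte 42, so a 16-byte vectorized ld/st into it would be misaligned; that is exactly the access pattern this commit fixes.
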
@@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Validate workspace size - must include space for auxiliary data + payloads
     int64_t sizePerRank = workspace.size(1);
-    int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
+    int64_t requiredSize = static_cast<int64_t>(currentOffset);
     TORCH_CHECK(sizePerRank >= requiredSize,
         "Workspace size per rank insufficient for dispatch. "
         "Need at least ",
-        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
-        " for payloads), but got ", sizePerRank);
+        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + payloads), but got ",
+        sizePerRank);
 
     // Get base workspace pointer
     uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
     uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
+        "rankWorkSpacePtr must be %d-byte aligned", CACHELINE_ALIGNMENT);
 
     // Setup payload descriptors for source data
     int num_payloads = static_cast<int>(inputPayloads.size());
@@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         params.completion_flags[target_rank]
             = reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
 
-        size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
         for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
         {
-            // Store pointer for current payload
-            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
-            // Update offset for next payload
-            offset += payloadByteSizes[payload_idx];
+            // Store pointer for current payload using pre-calculated aligned offset
+            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
         }
     }
 
@@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Create tensor views for the current rank's receive buffers only
     std::vector<torch::Tensor> recvTensors;
-    size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
     for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
     {
         auto const& payload = inputPayloads[payload_idx];
-        // Create tensor view for this payload
-        auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
+        // Create tensor view for this payload using pre-calculated aligned offset
+        auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
             {epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
         recvTensors.push_back(recvTensor);
-
-        // Update offset for next payload
-        offset += payloadByteSizes[payload_idx];
     }
 
     // Compute aligned offset after dispatch payloads for combine payload region
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
-    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
+    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
 
     return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
 }
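
torch::from_blob, used above, wraps existing memory in a tensor without copying and without taking ownership, so the returned recvTensors alias the rank's workspace at the pre-calculated aligned offsets. A minimal self-contained illustration with host memory and a toy shape (not the actual workspace):

#include <torch/torch.h>

int main()
{
    float buffer[2 * 3] = {}; // stand-in for rankWorkSpacePtr + payloadRecvBufferOffsets[i]
    auto view = torch::from_blob(buffer, {2, 3}, torch::TensorOptions().dtype(torch::kFloat32));
    view.fill_(1.0f);         // writes go straight through to `buffer`
    // The memory must outlive the view; here `buffer` lives until main returns.
}
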
@@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
     TORCH_CHECK(
         payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
+    // We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
+    // dynamically determined based on bytes per token of this payload.
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
     int64_t elementsPerToken = payload.size(2);
     TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
     TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
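
The comment above states that the vectorized ld/st width is chosen dynamically from the bytes per token, with the 16-byte check as the ceiling. The kernel's actual selection logic is not part of this diff; a plausible sketch of the idea, with a hypothetical helper name:

#include <cstddef>
#include <cstdint>

// Pick the widest power-of-two access width (up to 16 bytes) that divides both
// the base address and the per-token byte count, so no vectorized access is
// misaligned or straddles a token boundary. Hypothetical, not the kernel code.
inline int selectVectorWidthBytes(uintptr_t base, size_t bytesPerToken)
{
    for (int width = 16; width > 1; width /= 2)
    {
        if (base % width == 0 && bytesPerToken % width == 0)
        {
            return width;
        }
    }
    return 1; // fall back to scalar byte accesses
}

For instance, a 16-byte-aligned payload with 24 bytes per token would get 8-byte accesses, while the 128-byte-aligned recv buffers always qualify for the full 16-byte width.
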
@@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
         " for payload), but got ", sizePerRank);
 
     // Create output tensor (local on current rank), no need for initialization
+    // Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
     torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
 
     // Setup combine parameters
@@ -54,25 +54,31 @@ class MoeAlltoAll:
                                 dtype: torch.dtype,
                                 extra_payload_bytes_per_token: int = 0) -> int:
         element_size = dtype.itemsize
+
         # Auxiliary data size
-        aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
+        workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
 
         # Dispatch needs workspace for [ep_size, max_tokens] tokens,
-        # but due to the variety of quantization recipes, we cannot know the exact size,
-        # so we conservatively estimate assuming no quantization.
-        payload_size_dispatch = ep_size * max_num_tokens * (
-            hidden_size * element_size  # (Unquantized) token hidden states
-            + top_k * 4  # token_selected_experts
-            + top_k * 4  # token_final_scales
-            + extra_payload_bytes_per_token  # extra payload bytes per token
-        )
+        # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
+        # Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
+        # (Unquantized) token hidden states
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
+        # token_selected_experts
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # token_final_scales
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # extra payload bytes per token
+        workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
+        workspace_size = pad_up(workspace_size, 128)
 
         # Required workspace for combine [ep_size, max_tokens] tokens
-        payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
 
-        # Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
-        return pad_up(aux_size, 128) + pad_up(
-            payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
+        return workspace_size
 
     @classmethod
     def _init_constants(cls):
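
The difference between padding each term (new code) and padding the dispatch sum once (old code) is easiest to see with numbers. Below is a C++ mirror of the new sizing arithmetic; the shapes are deliberately tiny and illustrative, pad_up is assumed to round up to a multiple, and the auxiliary-data and extra-payload terms are set to zero for brevity:

#include <cstdint>
#include <cstdio>

constexpr int64_t pad_up(int64_t n, int64_t m) { return (n + m - 1) / m * m; }

int main()
{
    // Toy shapes: ep_size=2, max_num_tokens=3, hidden_size=7, 2-byte dtype, top_k=1.
    int64_t const epSize = 2, maxTokens = 3, hidden = 7, elemSize = 2, topK = 1;
    int64_t ws = 0;                                                // aux data size omitted
    ws = pad_up(ws + epSize * maxTokens * hidden * elemSize, 128); // hidden states: 84 -> 128
    ws = pad_up(ws + epSize * maxTokens * topK * 4, 128);          // selected experts: 152 -> 256
    ws = pad_up(ws + epSize * maxTokens * topK * 4, 128);          // final scales: 280 -> 384
    ws = pad_up(ws + epSize * maxTokens * hidden * elemSize, 128); // combine payload: 468 -> 512
    std::printf("%lld bytes\n", static_cast<long long>(ws));       // prints 512
}

The old formula, pad_up(84 + 24 + 24, 128) + pad_up(84, 128) = 256 + 128, reserves only 384 bytes for the same shapes, which is too small for the per-payload cacheline-aligned layout that now needs 512.
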
@@ -81,32 +81,31 @@ class NVLinkOneSided(Communication):
         extra_payload_bytes_per_token: int = 0,
     ) -> int:
         element_size = dtype.itemsize
+
         # Auxiliary data size
-        aux_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
+        workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
 
         # Dispatch needs workspace for [ep_size, max_tokens] tokens,
-        # but due to the variety of quantization recipes, we cannot know the exact size,
-        # so we conservatively estimate assuming no quantization.
-        payload_size_dispatch = (
-            ep_size
-            * max_num_tokens
-            * (
-                hidden_size * element_size  # (Unquantized) token hidden states
-                + top_k * 4  # token_selected_experts
-                + top_k * 4  # token_final_scales
-                + extra_payload_bytes_per_token  # extra payload bytes per token
-            )
-        )
+        # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
+        # Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
+        # (Unquantized) token hidden states
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
+        # token_selected_experts
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # token_final_scales
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # extra payload bytes per token
+        workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
+        workspace_size = pad_up(workspace_size, 128)
 
         # Required workspace for combine [ep_size, max_tokens] tokens
-        payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
 
-        # Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
-        return (
-            pad_up(aux_size, 128)
-            + pad_up(payload_size_dispatch, 128)
-            + pad_up(payload_size_combine, 128)
-        )
+        return workspace_size
 
     @classmethod
     def _init_constants(cls):