[TRTLLM-10296][fix] Fix the potential misaligned access due to vectorized ld/st instructions in NVLinkOneSided A2A. (#10539)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
This commit is contained in:
Bo Li 2026-01-20 11:08:04 +08:00 committed by GitHub
parent dbb858ae0c
commit f3a985ce27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 56 deletions

View File

@ -33,6 +33,8 @@ namespace torch_ext
namespace moe_comm
{
static constexpr size_t CACHELINE_ALIGNMENT = 128;
// TODO: Is Alignment necessary?
// Helper function to align offset to specified byte boundary
inline size_t alignOffset(size_t offset, size_t alignment)
@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
// TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
// read.
constexpr size_t SIZEOF_INT32 = 4;
constexpr size_t CACHELINE_ALIGNMENT = 128;
MoeA2ADataOffsets offsets;
size_t offset = 0;
@ -203,12 +204,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
}
// Calculate buffer sizes for all payloads
// Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
int64_t totalBytesNeeded = 0;
std::vector<int64_t> payloadByteSizes;
// Record the cacheline aligned start offset for each payload's recv buffer.
// 1. We assume the base workspace ptr of each rank is aligned (checked in this OP)
// 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets)
// 3. We align the currentOffset during update.
// In this way, it is guaranteed that the recv buffer is (over-)aligned, sufficient for 128bit vectorized ld/st.
std::vector<int> payloadElementSizes;
std::vector<int> payloadElementsPerToken;
std::vector<size_t> payloadRecvBufferOffsets;
// Start offset for the first payload
size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (auto const& payload : inputPayloads)
{
CHECK_CONTIGUOUS(payload);
@ -216,16 +223,24 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
TORCH_CHECK(
payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
// Unlike recv buffer for payloads, payload itself is not allocated by us and we cannot control its alignment.
// We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
// dynamically determined based on bytes per token of this payload.
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
int elementsPerToken = static_cast<int>(payload.size(1));
int elementSize = static_cast<int>(payload.dtype().itemsize());
// Each payload buffer stores data from ALL ranks
int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;
payloadByteSizes.push_back(bytesPerPayload);
payloadElementSizes.push_back(elementSize);
payloadElementsPerToken.push_back(elementsPerToken);
totalBytesNeeded += bytesPerPayload;
payloadRecvBufferOffsets.push_back(currentOffset);
// Update offset and align to cacheline boundary for the next payload recv buffer.
currentOffset += bytesPerPayload;
currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
}
CHECK_TH_CUDA(workspace);
@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
// Validate workspace size - must include space for auxiliary data + payloads
int64_t sizePerRank = workspace.size(1);
int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
int64_t requiredSize = static_cast<int64_t>(currentOffset);
TORCH_CHECK(sizePerRank >= requiredSize,
"Workspace size per rank insufficient for dispatch. "
"Need at least ",
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
" for payloads), but got ", sizePerRank);
requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + payloads), but got ",
sizePerRank);
// Get base workspace pointer
uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
"rankWorkSpacePtr must be %d-byte aligned", CACHELINE_ALIGNMENT);
// Setup payload descriptors for source data
int num_payloads = static_cast<int>(inputPayloads.size());
@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
params.completion_flags[target_rank]
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
// Store pointer for current payload
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
// Update offset for next payload
offset += payloadByteSizes[payload_idx];
// Store pointer for current payload using pre-calculated aligned offset
params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
}
}
@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
// Create tensor views for the current rank's receive buffers only
std::vector<torch::Tensor> recvTensors;
size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
auto const& payload = inputPayloads[payload_idx];
// Create tensor view for this payload
auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
// Create tensor view for this payload using pre-calculated aligned offset
auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
{epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
recvTensors.push_back(recvTensor);
// Update offset for next payload
offset += payloadByteSizes[payload_idx];
}
// Compute aligned offset after dispatch payloads for combine payload region
constexpr size_t CACHELINE_ALIGNMENT = 128;
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
}
@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
TORCH_CHECK(
payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
// We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
// dynamically determined based on bytes per token of this payload.
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
int64_t elementsPerToken = payload.size(2);
TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
" for payload), but got ", sizePerRank);
// Create output tensor (local on current rank), no need for initialization
// Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
// Setup combine parameters

View File

@ -54,25 +54,31 @@ class MoeAlltoAll:
dtype: torch.dtype,
extra_payload_bytes_per_token: int = 0) -> int:
element_size = dtype.itemsize
# Auxiliary data size
aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size,
# so we conservatively estimate assuming no quantization.
payload_size_dispatch = ep_size * max_num_tokens * (
hidden_size * element_size # (Unquantized) token hidden states
+ top_k * 4 # token_selected_experts
+ top_k * 4 # token_final_scales
+ extra_payload_bytes_per_token # extra payload bytes per token
)
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
# Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
# (Unquantized) token hidden states
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# token_selected_experts
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# token_final_scales
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# extra payload bytes per token
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
workspace_size = pad_up(workspace_size, 128)
# Required workspace for combine [ep_size, max_tokens] tokens
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
return pad_up(aux_size, 128) + pad_up(
payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
return workspace_size
@classmethod
def _init_constants(cls):

View File

@ -81,32 +81,31 @@ class NVLinkOneSided(Communication):
extra_payload_bytes_per_token: int = 0,
) -> int:
element_size = dtype.itemsize
# Auxiliary data size
aux_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
# Dispatch needs workspace for [ep_size, max_tokens] tokens,
# but due to the variety of quantization recipes, we cannot know the exact size,
# so we conservatively estimate assuming no quantization.
payload_size_dispatch = (
ep_size
* max_num_tokens
* (
hidden_size * element_size # (Unquantized) token hidden states
+ top_k * 4 # token_selected_experts
+ top_k * 4 # token_final_scales
+ extra_payload_bytes_per_token # extra payload bytes per token
)
)
# but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
# Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
# (Unquantized) token hidden states
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# token_selected_experts
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# token_final_scales
workspace_size += ep_size * max_num_tokens * top_k * 4
workspace_size = pad_up(workspace_size, 128)
# extra payload bytes per token
workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
workspace_size = pad_up(workspace_size, 128)
# Required workspace for combine [ep_size, max_tokens] tokens
payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
workspace_size += ep_size * max_num_tokens * hidden_size * element_size
workspace_size = pad_up(workspace_size, 128)
# Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
return (
pad_up(aux_size, 128)
+ pad_up(payload_size_dispatch, 128)
+ pad_up(payload_size_combine, 128)
)
return workspace_size
@classmethod
def _init_constants(cls):