diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
index af6d7cb37d..29ad780d4c 100644
--- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
+++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
@@ -33,6 +33,8 @@ namespace torch_ext
 namespace moe_comm
 {
 
+static constexpr size_t CACHELINE_ALIGNMENT = 128;
+
 // TODO: Is Alignment necessary?
 // Helper function to align offset to specified byte boundary
 inline size_t alignOffset(size_t offset, size_t alignment)
@@ -46,7 +48,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
     // TODO: Use lambdas to encapsulate offset and alignment for each entry, which is less error prone and easier to
     // read.
     constexpr size_t SIZEOF_INT32 = 4;
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
 
     MoeA2ADataOffsets offsets;
     size_t offset = 0;
@@ -203,12 +204,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         TORCH_CHECK(payload.is_contiguous(), "All payloads must be contiguous");
     }
 
-    // Calculate buffer sizes for all payloads
-    // Each payload buffer needs space for data from ALL ranks: epSize * maxTokensPerRank * elementsPerToken
-    int64_t totalBytesNeeded = 0;
-    std::vector<int64_t> payloadByteSizes;
+    // Record the cacheline-aligned start offset for each payload's recv buffer.
+    // 1. We assume the base workspace ptr of each rank is aligned (checked in this OP)
+    // 2. offsets[PAYLOAD_DATA_OFFSET_INDEX] is aligned (ensured in calculateOffsets)
+    // 3. We align the currentOffset during update.
+    // In this way, it is guaranteed that the recv buffer is (over-)aligned, sufficient for 128-bit vectorized ld/st.
    std::vector<int> payloadElementSizes;
     std::vector<int> payloadElementsPerToken;
+    std::vector<size_t> payloadRecvBufferOffsets;
+
+    // Start offset for the first payload
+    size_t currentOffset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
 
     for (auto const& payload : inputPayloads)
     {
         CHECK_CONTIGUOUS(payload);
@@ -216,16 +223,24 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         TORCH_CHECK(payload.dim() == 2, "payload must be a 2D tensor");
         TORCH_CHECK(
             payload.size(0) == localNumTokens, "payload must have the same first dimension as tokenSelectedExperts");
+        // Unlike the recv buffers for payloads, the payload itself is not allocated by us and we cannot control its
+        // alignment. We only make sure the payload start offset is 16-byte aligned, while the actual vectorized
+        // ld/st width is dynamically determined based on bytes per token of this payload.
+        TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
 
         int elementsPerToken = static_cast<int>(payload.size(1));
         int elementSize = static_cast<int>(payload.dtype().itemsize());
 
         // Each payload buffer stores data from ALL ranks
         int64_t bytesPerPayload = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSize;
-        payloadByteSizes.push_back(bytesPerPayload);
         payloadElementSizes.push_back(elementSize);
         payloadElementsPerToken.push_back(elementsPerToken);
-        totalBytesNeeded += bytesPerPayload;
+
+        payloadRecvBufferOffsets.push_back(currentOffset);
+
+        // Update offset and align to cacheline boundary for the next payload recv buffer.
+        currentOffset += bytesPerPayload;
+        currentOffset = alignOffset(currentOffset, CACHELINE_ALIGNMENT);
     }
 
     CHECK_TH_CUDA(workspace);
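
The offset bookkeeping in the hunks above is easy to sanity-check in isolation. A minimal Python sketch of the same arithmetic (helper names are illustrative, not part of the patch; align_offset mirrors the C++ alignOffset):

CACHELINE_ALIGNMENT = 128

def align_offset(offset: int, alignment: int) -> int:
    # Round up to the next multiple of `alignment`, as the C++ helper does.
    return (offset + alignment - 1) // alignment * alignment

def payload_recv_buffer_offsets(payload_data_offset: int,
                                bytes_per_payload: list[int]) -> tuple[list[int], int]:
    # Mirrors the dispatch loop: record the aligned start of each recv buffer,
    # then advance and re-align the running offset for the next payload.
    offsets, current = [], payload_data_offset
    for nbytes in bytes_per_payload:
        offsets.append(current)
        current = align_offset(current + nbytes, CACHELINE_ALIGNMENT)
    return offsets, current

# Example: a 100-byte payload is padded so the next buffer starts at offset 256.
assert payload_recv_buffer_offsets(128, [100, 64]) == ([128, 256], 384)
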
@@ -236,16 +251,18 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Validate workspace size - must include space for auxiliary data + payloads
     int64_t sizePerRank = workspace.size(1);
-    int64_t requiredSize = offsets[PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded;
+    int64_t requiredSize = static_cast<int64_t>(currentOffset);
     TORCH_CHECK(sizePerRank >= requiredSize,
         "Workspace size per rank insufficient for dispatch. "
         "Need at least ",
-        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + ", totalBytesNeeded,
-        " for payloads), but got ", sizePerRank);
+        requiredSize, " bytes (", offsets[PAYLOAD_DATA_OFFSET_INDEX], " for auxiliary data + payloads), but got ",
+        sizePerRank);
 
     // Get base workspace pointer
     uint8_t* workspacePtr = workspace.data_ptr<uint8_t>();
     uint8_t* rankWorkSpacePtr = workspacePtr + epRank * workspace.stride(0);
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(rankWorkSpacePtr) % CACHELINE_ALIGNMENT == 0,
+        "rankWorkSpacePtr must be ", CACHELINE_ALIGNMENT, "-byte aligned");
 
     // Setup payload descriptors for source data
     int num_payloads = static_cast<int>(inputPayloads.size());
@@ -288,13 +305,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
         params.completion_flags[target_rank]
             = reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
 
-        size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
         for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
         {
-            // Store pointer for current payload
-            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + offset;
-            // Update offset for next payload
-            offset += payloadByteSizes[payload_idx];
+            // Store pointer for current payload using pre-calculated aligned offset
+            params.recv_buffers[target_rank][payload_idx] = targetWorkSpacePtr + payloadRecvBufferOffsets[payload_idx];
         }
     }
 
@@ -310,22 +324,17 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
 
     // Create tensor views for the current rank's receive buffers only
     std::vector<torch::Tensor> recvTensors;
-    size_t offset = static_cast<size_t>(offsets[PAYLOAD_DATA_OFFSET_INDEX]);
 
     for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
     {
         auto const& payload = inputPayloads[payload_idx];
 
-        // Create tensor view for this payload
-        auto recvTensor = torch::from_blob(rankWorkSpacePtr + offset,
+        // Create tensor view for this payload using pre-calculated aligned offset
+        auto recvTensor = torch::from_blob(rankWorkSpacePtr + payloadRecvBufferOffsets[payload_idx],
             {epSize, runtimeMaxTokensPerRank, payloadElementsPerToken[payload_idx]}, payload.options());
         recvTensors.push_back(recvTensor);
-
-        // Update offset for next payload
-        offset += payloadByteSizes[payload_idx];
     }
 
     // Compute aligned offset after dispatch payloads for combine payload region
-    constexpr size_t CACHELINE_ALIGNMENT = 128;
-    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(static_cast<size_t>(offset), CACHELINE_ALIGNMENT));
+    int64_t combinePayloadOffset = static_cast<int64_t>(alignOffset(currentOffset, CACHELINE_ALIGNMENT));
 
     return std::make_tuple(std::move(recvTensors), combinePayloadOffset);
 }
@@ -356,6 +365,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     TORCH_CHECK(payload.size(0) == epSize, "payload first dimension must equal epSize");
     TORCH_CHECK(
         payload.size(1) == runtimeMaxTokensPerRank, "payload second dimension must equal runtimeMaxTokensPerRank");
+    // We only make sure the payload start offset is 16-byte aligned, while the actual vectorized ld/st width is
+    // dynamically determined based on bytes per token of this payload.
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
     int64_t elementsPerToken = payload.size(2);
     TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
     TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
@@ -411,6 +423,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
         " for payload), but got ", sizePerRank);
 
     // Create output tensor (local on current rank), no need for initialization
+    // Typically, newly allocated GPU torch tensors are at least 16-byte aligned.
    torch::Tensor output = torch::empty({localNumTokens, elementsPerToken}, payload.options());
 
     // Setup combine parameters
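
Both ops demand only 16-byte alignment from caller-provided payloads because, per the comments above, the kernels choose their load/store width per payload from its bytes per token. The patch does not include the kernel's selection rule; one plausible policy, for illustration only:

def pick_vector_width_bytes(bytes_per_token: int, max_width: int = 16) -> int:
    # Illustration only (not the kernel's actual code): use the widest
    # power-of-two access, up to 128 bits, that divides the per-token size.
    width = max_width
    while width > 1 and bytes_per_token % width != 0:
        width //= 2
    return width

assert pick_vector_width_bytes(4096 * 2) == 16  # bf16 hidden states: full 128-bit ld/st
assert pick_vector_width_bytes(12) == 4         # odd sizes fall back to narrower accesses
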
diff --git a/tensorrt_llm/_torch/distributed/moe_alltoall.py b/tensorrt_llm/_torch/distributed/moe_alltoall.py
index 1eeea09aca..231a671f3b 100644
--- a/tensorrt_llm/_torch/distributed/moe_alltoall.py
+++ b/tensorrt_llm/_torch/distributed/moe_alltoall.py
@@ -54,25 +54,31 @@ class MoeAlltoAll:
                              dtype: torch.dtype,
                              extra_payload_bytes_per_token: int = 0) -> int:
         element_size = dtype.itemsize
+
         # Auxiliary data size
-        aux_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
+        workspace_size = MoeAlltoAll.get_aux_data_size(ep_size, max_num_tokens)
 
         # Dispatch needs workspace for [ep_size, max_tokens] tokens,
-        # but due to the variety of quantization recipes, we cannot know the exact size,
-        # so we conservatively estimate assuming no quantization.
-        payload_size_dispatch = ep_size * max_num_tokens * (
-            hidden_size * element_size  # (Unquantized) token hidden states
-            + top_k * 4  # token_selected_experts
-            + top_k * 4  # token_final_scales
-            + extra_payload_bytes_per_token  # extra payload bytes per token
-        )
+        # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
+        # Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
+        # (Unquantized) token hidden states
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
+        # token_selected_experts
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # token_final_scales
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # extra payload bytes per token
+        workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
+        workspace_size = pad_up(workspace_size, 128)
 
         # Required workspace for combine [ep_size, max_tokens] tokens
-        payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
 
-        # Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
-        return pad_up(aux_size, 128) + pad_up(
-            payload_size_dispatch, 128) + pad_up(payload_size_combine, 128)
+        return workspace_size
 
     @classmethod
     def _init_constants(cls):
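
The new sizing walks the payloads in the same order as moeA2ADispatchOp and pads after every term, so the Python estimate can never undershoot the aligned offsets the C++ op computes. A standalone re-derivation for checking by hand (my own pad_up, assumed to match the helper these modules import; the function name below is mine):

def pad_up(x: int, align: int) -> int:
    # Assumed to match the imported pad_up helper: round up to a multiple.
    return (x + align - 1) // align * align

def workspace_size_per_rank(ep_size: int, max_num_tokens: int, hidden_size: int,
                            top_k: int, element_size: int,
                            extra_payload_bytes_per_token: int = 0,
                            aux_size: int = 0) -> int:
    size = aux_size
    # Dispatch payloads: hidden states, selected experts, final scales, extras.
    per_token_bytes = (hidden_size * element_size, top_k * 4, top_k * 4,
                       extra_payload_bytes_per_token)
    for bytes_per_token in per_token_bytes:
        size = pad_up(size + ep_size * max_num_tokens * bytes_per_token, 128)
    # Combine payload region.
    size = pad_up(size + ep_size * max_num_tokens * hidden_size * element_size, 128)
    return size

# e.g. ep_size=4, max_num_tokens=256, hidden_size=4096, top_k=8, bf16 (2 bytes):
# every term here is already a multiple of 128, so padding adds nothing.
assert workspace_size_per_rank(4, 256, 4096, 8, 2) == 2 * 8_388_608 + 2 * 32_768
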
diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
index c03270753e..2cbe066204 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
@@ -81,32 +81,31 @@ class NVLinkOneSided(Communication):
         extra_payload_bytes_per_token: int = 0,
     ) -> int:
         element_size = dtype.itemsize
+
         # Auxiliary data size
-        aux_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
+        workspace_size = NVLinkOneSided.get_aux_data_size(ep_size, max_num_tokens)
 
         # Dispatch needs workspace for [ep_size, max_tokens] tokens,
-        # but due to the variety of quantization recipes, we cannot know the exact size,
-        # so we conservatively estimate assuming no quantization.
-        payload_size_dispatch = (
-            ep_size
-            * max_num_tokens
-            * (
-                hidden_size * element_size  # (Unquantized) token hidden states
-                + top_k * 4  # token_selected_experts
-                + top_k * 4  # token_final_scales
-                + extra_payload_bytes_per_token  # extra payload bytes per token
-            )
-        )
+        # but due to the variety of quantization recipes, we cannot know the exact size, so we conservatively estimate assuming no quantization.
+        # Meanwhile, we consider the alignment requirement as in moeA2ADispatchOp and moeA2ACombineOp.
+        # (Unquantized) token hidden states
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
+        # token_selected_experts
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # token_final_scales
+        workspace_size += ep_size * max_num_tokens * top_k * 4
+        workspace_size = pad_up(workspace_size, 128)
+        # extra payload bytes per token
+        workspace_size += ep_size * max_num_tokens * extra_payload_bytes_per_token
+        workspace_size = pad_up(workspace_size, 128)
 
         # Required workspace for combine [ep_size, max_tokens] tokens
-        payload_size_combine = ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size += ep_size * max_num_tokens * hidden_size * element_size
+        workspace_size = pad_up(workspace_size, 128)
 
-        # Pad to 128 bytes to ensure alignment. This matches the implementation of C++ torch OP code.
-        return (
-            pad_up(aux_size, 128)
-            + pad_up(payload_size_dispatch, 128)
-            + pad_up(payload_size_combine, 128)
-        )
+        return workspace_size
 
     @classmethod
     def _init_constants(cls):
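
Taken together, the patch establishes one invariant: a 128-byte-aligned workspace base plus 128-byte-aligned per-payload offsets keeps every recv buffer (over-)aligned for 128-bit vectorized access, while caller-provided payloads only promise 16 bytes. A tiny checker expressing that contract (illustrative sketch, not part of the patch):

CACHELINE_ALIGNMENT = 128

def recv_buffers_aligned(rank_workspace_ptr: int, payload_offsets: list[int]) -> bool:
    # The base pointer must be cacheline-aligned (checked in moeA2ADispatchOp)...
    if rank_workspace_ptr % CACHELINE_ALIGNMENT != 0:
        return False
    # ...and every recv-buffer offset is cacheline-aligned by construction,
    # so base + offset is aligned for each payload.
    return all(off % CACHELINE_ALIGNMENT == 0 for off in payload_offsets)

assert recv_buffers_aligned(0x7F0000000000, [128, 256, 1024])
assert not recv_buffers_aligned(0x7F0000000010, [128])
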