// Source listing metadata (commented out; not valid C++):
//   TensorRT-LLMs/cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp
//   Zongfei Jing 53163bf1df
//   [TRTLLM-6876][feat] Add low precision all2all for mnnvl (#7155)
//   Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
//   2025-08-28 18:26:16 +08:00
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <functional>
#include <memory>
#include <random>
#include <thread>
#include <vector>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
using namespace tensorrt_llm::kernels;
// Shared fixture for fused MoE communication kernel tests: GPU gating, stream
// lifetime, host/device buffer helpers, and deterministic mapping generation.
class FusedMoeCommTestBase : public ::testing::Test
{
protected:
    // Skip when no CUDA device is available or the GPU is older than SM90.
    static bool shouldSkip()
    {
        int deviceCount = tensorrt_llm::common::getDeviceCount();
        if (deviceCount <= 0)
        {
            return true;
        }
        int sm = tensorrt_llm::common::getSMVersion();
        if (sm < 90)
        {
            return true;
        }
        return false;
    }

    void SetUp() override
    {
        if (shouldSkip())
        {
            skipped = true;
            GTEST_SKIP() << "Skipping due to no/unsupported GPU";
        }
        TLLM_CUDA_CHECK(cudaStreamCreate(&stream));
        std::srand(42); // Fixed seed so rand()-based default initialization is reproducible
    }

    void TearDown() override
    {
        // Only destroy the stream if SetUp actually created one (i.e. not skipped).
        if (!skipped)
        {
            TLLM_CUDA_CHECK(cudaStreamDestroy(stream));
        }
    }

    // Map an element size in bytes to a cudaDataType_t of the same width.
    // Only the width matters to the kernels under test, not the type semantics.
    cudaDataType_t getCudaDataType(int elementSize)
    {
        switch (elementSize)
        {
        case 1: return CUDA_R_8U;
        case 2: return CUDA_R_16F;
        case 4: return CUDA_R_32F;
        case 8: return CUDA_R_64F;
        case 16: return CUDA_C_64F;
        default: TLLM_THROW("Unsupported element size: %d", elementSize);
        };
    }

    bool skipped = false;
    cudaStream_t stream = nullptr;

    // Allocate a host buffer (new[]) and a matching device buffer (cudaMalloc) of
    // `count` elements, fill the host side either via `generator(index)` or with
    // rand()-based defaults, and copy host -> device synchronously.
    // Release both sides with cleanup().
    template <typename T>
    void allocateAndInitializeData(
        T** hostPtr, T** devicePtr, size_t count, std::function<T(size_t)> generator = nullptr)
    {
        *hostPtr = new T[count];
        TLLM_CUDA_CHECK(cudaMalloc(devicePtr, count * sizeof(T)));
        if (generator)
        {
            for (size_t i = 0; i < count; i++)
            {
                (*hostPtr)[i] = generator(i);
            }
        }
        else
        {
            // Default initialization with random values
            for (size_t i = 0; i < count; i++)
            {
                if constexpr (std::is_same_v<T, float>)
                {
                    (*hostPtr)[i] = static_cast<float>(rand()) / RAND_MAX * 10.0f;
                }
                else if constexpr (std::is_same_v<T, int>)
                {
                    (*hostPtr)[i] = rand() % 1000;
                }
                else
                {
                    (*hostPtr)[i] = static_cast<T>(rand() % 100);
                }
            }
        }
        TLLM_CUDA_CHECK(cudaMemcpy(*devicePtr, *hostPtr, count * sizeof(T), cudaMemcpyHostToDevice));
    }

    // Free a host/device buffer pair created by allocateAndInitializeData().
    // NOTE(review): delete[] through char* is technically UB for non-char element
    // types; it is benign here because all buffers hold trivially destructible types.
    void cleanup(void* hostPtr, void* devicePtr)
    {
        delete[] static_cast<char*>(hostPtr);
        TLLM_CUDA_CHECK(cudaFree(devicePtr));
    }

    // Build a one-to-one (permutation) mapping of size `totalSize`. Valid entries of
    // `partialMapping` are kept; invalid/duplicate entries and positions beyond the
    // provided mapping are filled from a deterministic shuffle of the unused indices.
    std::vector<int> generateOneToOneMapping(std::vector<int> const& partialMapping, int totalSize)
    {
        std::vector<int> fullMapping(totalSize);
        std::vector<bool> used(totalSize, false);
        // First, copy the provided mapping and mark used indices
        int providedSize = static_cast<int>(partialMapping.size());
        for (int i = 0; i < std::min(providedSize, totalSize); i++)
        {
            int target = partialMapping[i];
            if (target >= 0 && target < totalSize && !used[target])
            {
                fullMapping[i] = target;
                used[target] = true;
            }
            else
            {
                // Invalid or duplicate mapping entry; patched below
                fullMapping[i] = -1;
            }
        }
        // Collect unused indices
        std::vector<int> unusedIndices;
        for (int i = 0; i < totalSize; i++)
        {
            if (!used[i])
            {
                unusedIndices.push_back(i);
            }
        }
        // Shuffle unused indices with a fixed seed for reproducible tests.
        // std::shuffle replaces std::random_shuffle, which was removed in C++17.
        std::mt19937 rng(42);
        std::shuffle(unusedIndices.begin(), unusedIndices.end(), rng);
        // Fill in any invalid mappings and extend with remaining unused indices
        size_t unusedIdx = 0;
        for (int i = 0; i < totalSize; i++)
        {
            if (i < providedSize && fullMapping[i] == -1)
            {
                // Fix invalid mapping
                if (unusedIdx < unusedIndices.size())
                {
                    fullMapping[i] = unusedIndices[unusedIdx++];
                }
            }
            else if (i >= providedSize)
            {
                // Extend mapping
                if (unusedIdx < unusedIndices.size())
                {
                    fullMapping[i] = unusedIndices[unusedIdx++];
                }
                else
                {
                    // Fallback: identity mapping for remaining positions
                    fullMapping[i] = i;
                }
            }
        }
        return fullMapping;
    }
};
// Test class for launchSingleG2S: stages per-token data from global memory into a
// per-warp shared-memory image, which the test kernel dumps back to a global
// buffer so the host can inspect it.
class FusedMoeCommG2STest : public FusedMoeCommTestBase
{
protected:
    // Drives one launchSingleG2S run and verifies the dumped shared-memory image.
    //   topK            - experts selected per token (width of the basic fields)
    //   hasScales       - whether expert scales are part of the basic fields
    //   hasBasicFields  - whether token slots (and optionally scales) are staged
    //   sendFieldCount  - number of extra payload fields
    //   elementSizes / vectorSizes - per-field element width (bytes) and vector
    //                    length; cycled if sendFieldCount exceeds their size
    void runG2STest(int topK, bool hasScales, bool hasBasicFields, int sendFieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes, int tokenCount = 4,
        int warpsPerBlock = 2)
    {
        // Setup expert parallel info
        MoeExpertParallelInfo expertParallelInfo;
        expertParallelInfo.topK = topK;
        expertParallelInfo.expertCount = 8;
        // Setup send field info
        FusedMoeFieldInfo sendFieldInfo = {};
        sendFieldInfo.isBasicInterleaved = false;
        sendFieldInfo.fieldCount = sendFieldCount;
        // Allocate token selected slots and expert scales if needed
        int* hostTokenSlots = nullptr;
        int* deviceTokenSlots = nullptr;
        float* hostScales = nullptr;
        float* deviceScales = nullptr;
        if (hasBasicFields)
        {
            // Slot pattern i % 8 keeps values within expertCount
            allocateAndInitializeData<int>(&hostTokenSlots, &deviceTokenSlots, tokenCount * topK,
                [](size_t i) { return static_cast<int>(i % 8); });
            sendFieldInfo.tokenSelectedSlots = deviceTokenSlots;
            if (hasScales)
            {
                // Distinct, easily recognizable scale per (token, k) pair
                allocateAndInitializeData<float>(&hostScales, &deviceScales, tokenCount * topK,
                    [](size_t i) -> float { return 1.0f + static_cast<float>(i) * 0.1f; });
                sendFieldInfo.expertScales = deviceScales;
            }
        }
        // Setup send field info using new fillFieldInfo helper
        std::vector<void*> hostFieldPtrs(sendFieldCount);
        std::vector<void*> deviceFieldPtrs(sendFieldCount);
        for (int i = 0; i < sendFieldCount; i++)
        {
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            // Allocate field data with a per-field byte pattern
            uint8_t* hostField;
            uint8_t* deviceField;
            allocateAndInitializeData<uint8_t>(&hostField, &deviceField, fieldSize,
                [i](size_t idx) { return static_cast<uint8_t>((i * 100 + idx) % 128); });
            hostFieldPtrs[i] = hostField;
            deviceFieldPtrs[i] = deviceField;
            // Use the new fillFieldInfo helper function (stride == vectorSize)
            sendFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
        }
        // Fill field placement info
        sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);
        // Compute per-token shared memory size and allocate the dump buffer
        // (one warp image per token, zero-initialized).
        int warpShmSize = sendFieldInfo.computeSingleUncompactSize(topK, hasScales, hasBasicFields);
        size_t shmDumpSize = tokenCount * warpShmSize;
        size_t shmDumpIntCount = shmDumpSize / sizeof(int);
        int* hostShmDump;
        int* deviceShmDump;
        allocateAndInitializeData<int>(&hostShmDump, &deviceShmDump, shmDumpIntCount, [](size_t) { return 0; });
        // Launch G2S kernel with new signature
        fused_moe_comm_tests::launchSingleG2S(
            sendFieldInfo, expertParallelInfo, tokenCount, deviceShmDump, warpsPerBlock, hasBasicFields, stream);
        TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
        // Copy back results
        int* resultShmDump = new int[shmDumpIntCount];
        TLLM_CUDA_CHECK(
            cudaMemcpy(resultShmDump, deviceShmDump, shmDumpIntCount * sizeof(int), cudaMemcpyDeviceToHost));
        // Verify results
        verifyG2SResults(resultShmDump, hostTokenSlots, hostScales, hostFieldPtrs, topK, hasScales, hasBasicFields,
            sendFieldCount, elementSizes, vectorSizes, tokenCount, warpsPerBlock, warpShmSize);
        // Cleanup (paired with the allocations above)
        if (hasBasicFields)
        {
            cleanup(hostTokenSlots, deviceTokenSlots);
            if (hasScales)
            {
                cleanup(hostScales, deviceScales);
            }
        }
        for (int i = 0; i < sendFieldCount; i++)
        {
            cleanup(hostFieldPtrs[i], deviceFieldPtrs[i]);
        }
        cleanup(hostShmDump, deviceShmDump);
        delete[] resultShmDump;
    }
private:
    // Checks the dumped shared-memory image against the host-side inputs.
    // Currently only the basic-field region is verified: ints [0, topK) hold the
    // token slots and, when present, floats [topK, 2*topK) hold the scales.
    // Extra payload fields are not byte-checked here (see comment at the end).
    void verifyG2SResults(int const* shmDump, int const* expectedTokenSlots, float const* expectedScales,
        std::vector<void*> const& expectedFields, int topK, bool hasScales, bool hasBasicFields, int sendFieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes, int tokenCount,
        int warpsPerBlock, int warpShmSize)
    {
        for (int tokenId = 0; tokenId < tokenCount; tokenId++)
        {
            // Start of this token's warp image, in ints
            int const* warpShmData = shmDump + tokenId * warpShmSize / sizeof(int);
            // Verify token slots and scales only if hasBasicFields is true
            if (hasBasicFields)
            {
                // Verify token slots
                if (expectedTokenSlots)
                {
                    for (int k = 0; k < topK; k++)
                    {
                        int expected = expectedTokenSlots[tokenId * topK + k];
                        int actual = warpShmData[k];
                        EXPECT_EQ(expected, actual) << "Token slot mismatch at warp=" << tokenId << ", k=" << k;
                    }
                }
                // Verify scales if present (stored as floats right after the slots)
                if (hasScales && expectedScales)
                {
                    for (int k = 0; k < topK; k++)
                    {
                        float expected = expectedScales[tokenId * topK + k];
                        float actual = reinterpret_cast<float const*>(warpShmData)[topK + k];
                        EXPECT_NEAR(expected, actual, 1e-6f) << "Scale mismatch at warp=" << tokenId << ", k=" << k;
                    }
                }
            }
            // Additional field verification can be added here if needed
            // For now, we just verify that the operation completed successfully
        }
    }
};
// Test class for launchSingleS2G: preloads a synthetic per-token shared-memory
// image on the device, runs the S2G (shared-to-global) path, and checks that the
// basic fields (token slots / scales) land in the expected global buffers.
class FusedMoeCommS2GTest : public FusedMoeCommTestBase
{
protected:
    // Drives one launchSingleS2G run.
    //   topK            - experts selected per token (width of the basic fields)
    //   hasScales       - whether expert scales are part of the basic fields
    //   hasBasicFields  - whether token slots (and optionally scales) are written
    //   recvFieldCount  - number of extra payload fields
    //   elementSizes / vectorSizes - per-field element width (bytes) and vector
    //                    length; cycled if recvFieldCount exceeds their size
    void runS2GTest(int topK, bool hasScales, bool hasBasicFields, int recvFieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes, int tokenCount = 4,
        int warpsPerBlock = 2)
    {
        // Setup expert parallel info
        MoeExpertParallelInfo expertParallelInfo;
        expertParallelInfo.topK = topK;
        expertParallelInfo.expertCount = 8;
        // Setup recv field info
        FusedMoeFieldInfo recvFieldInfo = {};
        recvFieldInfo.isBasicInterleaved = false;
        recvFieldInfo.fieldCount = recvFieldCount;
        // Allocate token selected slots and expert scales if needed
        int* hostTokenSlots = nullptr;
        int* deviceTokenSlots = nullptr;
        float* hostScales = nullptr;
        float* deviceScales = nullptr;
        if (hasBasicFields)
        {
            allocateAndInitializeData<int>(&hostTokenSlots, &deviceTokenSlots, tokenCount * topK,
                [](size_t) { return 0; }); // Initialize to zero, will be filled by S2G
            recvFieldInfo.tokenSelectedSlots = deviceTokenSlots;
            if (hasScales)
            {
                allocateAndInitializeData<float>(&hostScales, &deviceScales, tokenCount * topK,
                    [](size_t) { return 0.0f; }); // Initialize to zero, will be filled by S2G
                recvFieldInfo.expertScales = deviceScales;
            }
        }
        // Setup recv field info using the fillFieldInfo helper
        std::vector<void*> hostFieldPtrs(recvFieldCount);
        std::vector<void*> deviceFieldPtrs(recvFieldCount);
        for (int i = 0; i < recvFieldCount; i++)
        {
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            // Allocate field data (initialize to zero, will be filled by S2G)
            uint8_t* hostField;
            uint8_t* deviceField;
            allocateAndInitializeData<uint8_t>(
                &hostField, &deviceField, fieldSize, [](size_t) { return static_cast<uint8_t>(0); });
            hostFieldPtrs[i] = hostField;
            deviceFieldPtrs[i] = deviceField;
            // Use the fillFieldInfo helper function (stride == vectorSize)
            recvFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
        }
        // Fill field placement info
        recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);
        // Compute shared memory size and prepare input data
        int warpShmSize = recvFieldInfo.computeSingleUncompactSize(topK, hasScales, hasBasicFields);
        size_t shmPreloadSize = tokenCount * warpShmSize;
        size_t shmPreloadIntCount = shmPreloadSize / sizeof(int);
        // Per-token int count of one shared-memory image. BUGFIX: previously the
        // TOTAL buffer int count was passed to the generator, so warpIdx was always
        // 0 and the intended per-token pattern was never produced. Passing the
        // per-token count makes warpIdx/offsetInWarp relative to each token's image.
        int warpIntCount = warpShmSize / static_cast<int>(sizeof(int));
        int* hostShmPreload;
        int* deviceShmPreload;
        allocateAndInitializeData<int>(&hostShmPreload, &deviceShmPreload, shmPreloadIntCount,
            [this, topK, hasScales, hasBasicFields, warpIntCount](size_t idx)
            { return this->generateShmPreloadData(idx, topK, hasScales, hasBasicFields, warpIntCount); });
        // Launch S2G kernel with new signature
        fused_moe_comm_tests::launchSingleS2G(
            recvFieldInfo, expertParallelInfo, tokenCount, deviceShmPreload, warpsPerBlock, hasBasicFields, stream);
        TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
        // Copy back results only if hasBasicFields
        int* resultTokenSlots = nullptr;
        float* resultScales = nullptr;
        if (hasBasicFields)
        {
            resultTokenSlots = new int[tokenCount * topK];
            TLLM_CUDA_CHECK(cudaMemcpy(
                resultTokenSlots, deviceTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost));
            if (hasScales)
            {
                resultScales = new float[tokenCount * topK];
                TLLM_CUDA_CHECK(
                    cudaMemcpy(resultScales, deviceScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost));
            }
        }
        // Verify results against the host copy of the preloaded image
        verifyS2GResults(resultTokenSlots, resultScales, hostShmPreload, topK, hasScales, hasBasicFields, tokenCount,
            warpsPerBlock, warpShmSize);
        // Cleanup (paired with the allocations above)
        if (hasBasicFields)
        {
            cleanup(hostTokenSlots, deviceTokenSlots);
            if (hasScales)
            {
                cleanup(hostScales, deviceScales);
            }
        }
        for (int i = 0; i < recvFieldCount; i++)
        {
            cleanup(hostFieldPtrs[i], deviceFieldPtrs[i]);
        }
        cleanup(hostShmPreload, deviceShmPreload);
        if (resultTokenSlots)
        {
            delete[] resultTokenSlots;
        }
        if (resultScales)
        {
            delete[] resultScales;
        }
    }
private:
    // Produce the int at flat index `idx` of the preloaded shared-memory buffer.
    // `warpIntCount` is the int count of ONE token's image, so warpIdx is the token
    // index and offsetInWarp the position inside that token's image.
    // Layout per image: ints [0, topK) = token slots, then (when hasScales) floats
    // [topK, 2*topK) = scales (bit-cast into ints), then generic field data.
    int generateShmPreloadData(size_t idx, int topK, bool hasScales, bool hasBasicFields, int warpIntCount)
    {
        size_t warpIdx = idx / warpIntCount;
        size_t offsetInWarp = idx % warpIntCount;
        if (hasBasicFields)
        {
            if (offsetInWarp < topK)
            {
                // Token slots area
                return static_cast<int>(warpIdx * 10 + offsetInWarp);
            }
            else if (hasScales && offsetInWarp < topK * 2)
            {
                // Scales area: encode a float pattern as its bit representation
                float scale
                    = 1.0f + static_cast<float>(warpIdx) * 0.1f + static_cast<float>(offsetInWarp - topK) * 0.01f;
                return *reinterpret_cast<int*>(&scale);
            }
            else
            {
                // Other field data
                return static_cast<int>((warpIdx * 1000 + offsetInWarp) % 128);
            }
        }
        else
        {
            // Only field data when no basic fields
            return static_cast<int>((warpIdx * 1000 + offsetInWarp) % 128);
        }
    }

    // Compares the S2G outputs (token slots / scales copied back from the device)
    // against the host copy of the preloaded shared-memory image. Extra payload
    // fields are not byte-checked. `warpsPerBlock` is unused here; kept to mirror
    // the launch configuration.
    void verifyS2GResults(int const* resultTokenSlots, float const* resultScales, int const* shmPreloadData, int topK,
        bool hasScales, bool hasBasicFields, int tokenCount, int warpsPerBlock, int warpShmSize)
    {
        if (!hasBasicFields)
        {
            // For non-basic fields tests, just verify that the operation completed successfully
            // without errors. The actual field data verification would require more complex setup.
            return;
        }
        for (int tokenId = 0; tokenId < tokenCount; tokenId++)
        {
            // Start of this token's preloaded image, in ints
            int const* warpShmData = shmPreloadData + tokenId * warpShmSize / sizeof(int);
            // Verify token slots were written correctly
            if (resultTokenSlots)
            {
                for (int k = 0; k < topK; k++)
                {
                    int expected = warpShmData[k];
                    int actual = resultTokenSlots[tokenId * topK + k];
                    EXPECT_EQ(expected, actual) << "Token slot mismatch at warp=" << tokenId << ", k=" << k;
                }
            }
            // Verify scales if present
            if (hasScales && resultScales)
            {
                for (int k = 0; k < topK; k++)
                {
                    float expected = reinterpret_cast<float const*>(warpShmData)[topK + k];
                    float actual = resultScales[tokenId * topK + k];
                    EXPECT_NEAR(expected, actual, 1e-6f) << "Scale mismatch at warp=" << tokenId << ", k=" << k;
                }
            }
        }
    }
};
// Test class for launchLoopback: runs the full G2S + pack + unpack + S2G path,
// sending token data through a recvIndexMapping permutation and verifying
// recv[mapping[send]] == send for basic fields and all payload fields.
class FusedMoeCommLoopbackTest : public FusedMoeCommTestBase
{
protected:
    // Drives one launchLoopback run.
    //   recvIndexMappingVec - partial send->recv index mapping; extended to a full
    //                         permutation of tokenCount via generateOneToOneMapping
    //   other parameters mirror runG2STest/runS2GTest
    void runLoopbackTest(int topK, bool hasScales, bool hasBasicFields, int fieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes,
        std::vector<int> const& recvIndexMappingVec, int tokenCount = 4, int warpsPerBlock = 2)
    {
        // Setup expert parallel info
        MoeExpertParallelInfo expertParallelInfo;
        expertParallelInfo.topK = topK;
        expertParallelInfo.expertCount = 8;
        // Setup field info - for loopback test, send and recv fields should be identical
        FusedMoeFieldInfo sendFieldInfo = {};
        sendFieldInfo.isBasicInterleaved = false;
        sendFieldInfo.fieldCount = fieldCount;
        FusedMoeFieldInfo recvFieldInfo = {};
        recvFieldInfo.isBasicInterleaved = false;
        recvFieldInfo.fieldCount = fieldCount;
        // Allocate token selected slots and expert scales if needed
        int* hostSendTokenSlots = nullptr;
        int* deviceSendTokenSlots = nullptr;
        float* hostSendScales = nullptr;
        float* deviceSendScales = nullptr;
        int* hostRecvTokenSlots = nullptr;
        int* deviceRecvTokenSlots = nullptr;
        float* hostRecvScales = nullptr;
        float* deviceRecvScales = nullptr;
        if (hasBasicFields)
        {
            // Send side basic fields (slot pattern i % 8 stays within expertCount)
            allocateAndInitializeData<int>(&hostSendTokenSlots, &deviceSendTokenSlots, tokenCount * topK,
                [](size_t i) { return static_cast<int>(i % 8); });
            sendFieldInfo.tokenSelectedSlots = deviceSendTokenSlots;
            // Recv side basic fields (initialized to zero, will be filled by loopback)
            allocateAndInitializeData<int>(
                &hostRecvTokenSlots, &deviceRecvTokenSlots, tokenCount * topK, [](size_t) { return 0; });
            recvFieldInfo.tokenSelectedSlots = deviceRecvTokenSlots;
            if (hasScales)
            {
                allocateAndInitializeData<float>(&hostSendScales, &deviceSendScales, tokenCount * topK,
                    [](size_t i) -> float { return 1.0f + static_cast<float>(i) * 0.1f; });
                sendFieldInfo.expertScales = deviceSendScales;
                allocateAndInitializeData<float>(
                    &hostRecvScales, &deviceRecvScales, tokenCount * topK, [](size_t) { return 0.0f; });
                recvFieldInfo.expertScales = deviceRecvScales;
            }
        }
        // Setup field info - both send and recv use same layout for loopback
        std::vector<void*> hostSendFieldPtrs(fieldCount);
        std::vector<void*> deviceSendFieldPtrs(fieldCount);
        std::vector<void*> hostRecvFieldPtrs(fieldCount);
        std::vector<void*> deviceRecvFieldPtrs(fieldCount);
        for (int i = 0; i < fieldCount; i++)
        {
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            // Allocate send field data with specific pattern
            uint8_t* hostSendField;
            uint8_t* deviceSendField;
            allocateAndInitializeData<uint8_t>(&hostSendField, &deviceSendField, fieldSize,
                [i](size_t idx) { return static_cast<uint8_t>((i * 100 + idx + 1) % 128); });
            // Allocate recv field data (initially zero, will be filled by loopback)
            uint8_t* hostRecvField;
            uint8_t* deviceRecvField;
            allocateAndInitializeData<uint8_t>(
                &hostRecvField, &deviceRecvField, fieldSize, [](size_t) { return static_cast<uint8_t>(0); });
            hostSendFieldPtrs[i] = hostSendField;
            deviceSendFieldPtrs[i] = deviceSendField;
            hostRecvFieldPtrs[i] = hostRecvField;
            deviceRecvFieldPtrs[i] = deviceRecvField;
            // Fill field info for both send and recv (stride == vectorSize)
            sendFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceSendField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
            recvFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceRecvField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
        }
        // Fill field placement info
        sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);
        recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);
        // Setup recvIndexMapping - ensure one-to-one mapping (full permutation)
        std::vector<int> fullMapping = generateOneToOneMapping(recvIndexMappingVec, tokenCount);
        int* hostRecvIndexMapping;
        int* deviceRecvIndexMapping;
        allocateAndInitializeData<int>(&hostRecvIndexMapping, &deviceRecvIndexMapping, tokenCount,
            [&fullMapping](size_t i) { return fullMapping[i]; });
        // Launch loopback kernel
        fused_moe_comm_tests::launchLoopback(sendFieldInfo, recvFieldInfo, expertParallelInfo, deviceRecvIndexMapping,
            tokenCount, warpsPerBlock, hasBasicFields, stream);
        TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
        // Copy back results and verify
        verifyLoopbackResults(hostSendTokenSlots, hostSendScales, hostSendFieldPtrs, hostRecvFieldPtrs,
            deviceRecvTokenSlots, deviceRecvScales, deviceRecvFieldPtrs, fullMapping, topK, hasScales, hasBasicFields,
            fieldCount, elementSizes, vectorSizes, tokenCount);
        // Cleanup (paired with the allocations above)
        if (hasBasicFields)
        {
            cleanup(hostSendTokenSlots, deviceSendTokenSlots);
            cleanup(hostRecvTokenSlots, deviceRecvTokenSlots);
            if (hasScales)
            {
                cleanup(hostSendScales, deviceSendScales);
                cleanup(hostRecvScales, deviceRecvScales);
            }
        }
        for (int i = 0; i < fieldCount; i++)
        {
            cleanup(hostSendFieldPtrs[i], deviceSendFieldPtrs[i]);
            cleanup(hostRecvFieldPtrs[i], deviceRecvFieldPtrs[i]);
        }
        cleanup(hostRecvIndexMapping, deviceRecvIndexMapping);
    }
private:
    // Copies the recv-side device buffers back to the host and checks that every
    // sent token landed at fullMapping[sendIndex] unchanged (basic fields and all
    // payload field bytes).
    // NOTE(review): `hostRecvFields` is currently unused; recv data is always read
    // back fresh from the device buffers.
    void verifyLoopbackResults(int const* expectedSendTokenSlots, float const* expectedSendScales,
        std::vector<void*> const& expectedSendFields, std::vector<void*> const& hostRecvFields,
        int* deviceRecvTokenSlots, float* deviceRecvScales, std::vector<void*> const& deviceRecvFields,
        std::vector<int> const& fullMapping, int topK, bool hasScales, bool hasBasicFields, int fieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes, int tokenCount)
    {
        // Copy back device results for verification
        int* resultRecvTokenSlots = nullptr;
        float* resultRecvScales = nullptr;
        if (hasBasicFields)
        {
            resultRecvTokenSlots = new int[tokenCount * topK];
            TLLM_CUDA_CHECK(cudaMemcpy(
                resultRecvTokenSlots, deviceRecvTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost));
            if (hasScales)
            {
                resultRecvScales = new float[tokenCount * topK];
                TLLM_CUDA_CHECK(cudaMemcpy(
                    resultRecvScales, deviceRecvScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost));
            }
        }
        // Copy back field data
        std::vector<uint8_t*> resultRecvFields(fieldCount);
        for (int i = 0; i < fieldCount; i++)
        {
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            resultRecvFields[i] = new uint8_t[fieldSize];
            TLLM_CUDA_CHECK(cudaMemcpy(resultRecvFields[i], deviceRecvFields[i], fieldSize, cudaMemcpyDeviceToHost));
        }
        // Verify the loopback: recv[fullMapping[sendIndex]] should equal send[sendIndex]
        int tokenSlotErrorCount = 0;
        int scaleErrorCount = 0;
        std::vector<int> fieldErrorCounts(fieldCount, 0);
        for (int sendIndex = 0; sendIndex < tokenCount; sendIndex++)
        {
            int recvIndex = fullMapping[sendIndex];
            ASSERT_GE(recvIndex, 0) << "Invalid recv index mapping at " << sendIndex;
            ASSERT_LT(recvIndex, tokenCount) << "Recv index out of bounds at " << sendIndex;
            // Verify basic fields if present
            if (hasBasicFields)
            {
                // Verify token slots
                if (expectedSendTokenSlots && resultRecvTokenSlots)
                {
                    for (int k = 0; k < topK; k++)
                    {
                        int expected = expectedSendTokenSlots[sendIndex * topK + k];
                        int actual = resultRecvTokenSlots[recvIndex * topK + k];
                        EXPECT_EQ(expected, actual) << "Token slot loopback mismatch: send[" << sendIndex << "][" << k
                                                    << "] -> recv[" << recvIndex << "][" << k << "]";
                    }
                }
                // Verify scales if present
                if (hasScales && expectedSendScales && resultRecvScales)
                {
                    for (int k = 0; k < topK; k++)
                    {
                        float expected = expectedSendScales[sendIndex * topK + k];
                        float actual = resultRecvScales[recvIndex * topK + k];
                        EXPECT_NEAR(expected, actual, 1e-6f) << "Scale loopback mismatch: send[" << sendIndex << "]["
                                                             << k << "] -> recv[" << recvIndex << "][" << k << "]";
                    }
                }
            }
            // Verify field data byte-for-byte (fieldSize here is bytes per token)
            for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++)
            {
                size_t elementSize = elementSizes[fieldIdx % elementSizes.size()];
                uint16_t vectorSize = vectorSizes[fieldIdx % vectorSizes.size()];
                size_t fieldSize = elementSize * vectorSize;
                uint8_t const* expectedSendField = static_cast<uint8_t const*>(expectedSendFields[fieldIdx]);
                uint8_t const* actualRecvField = resultRecvFields[fieldIdx];
                for (size_t byteIdx = 0; byteIdx < fieldSize; byteIdx++)
                {
                    uint8_t expected = expectedSendField[sendIndex * fieldSize + byteIdx];
                    uint8_t actual = actualRecvField[recvIndex * fieldSize + byteIdx];
                    EXPECT_EQ(expected, actual)
                        << "Field loopback mismatch: field[" << fieldIdx << "] send[" << sendIndex << "][" << byteIdx
                        << "] -> recv[" << recvIndex << "][" << byteIdx << "]";
                }
            }
        }
        // Cleanup temporary arrays
        if (resultRecvTokenSlots)
            delete[] resultRecvTokenSlots;
        if (resultRecvScales)
            delete[] resultRecvScales;
        for (int i = 0; i < fieldCount; i++)
        {
            if (resultRecvFields[i])
                delete[] resultRecvFields[i];
        }
    }
};
// Tests for G2S (global-to-shared staging) functionality.
// Argument order: topK, hasScales, hasBasicFields, fieldCount, elementSizes,
// vectorSizes[, tokenCount, warpsPerBlock].
TEST_F(FusedMoeCommG2STest, BasicG2SWithoutScales)
{
    runG2STest(2, false, true, 1, {4}, {64}); // topK=2, no scales, has basic fields, 1 field, 4-byte elements, 64 units
}
TEST_F(FusedMoeCommG2STest, BasicG2SWithScales)
{
    runG2STest(
        4, true, true, 1, {4}, {32}); // topK=4, with scales, has basic fields, 1 field, 4-byte elements, 32 units
}
TEST_F(FusedMoeCommG2STest, MultipleFieldsVariousAlignments)
{
    runG2STest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}); // Multiple fields with different element sizes
}
TEST_F(FusedMoeCommG2STest, LargeTopK)
{
    runG2STest(8, true, true, 2, {4, 8}, {128, 256}); // Large topK value
}
TEST_F(FusedMoeCommG2STest, PerfectAlignmentFields)
{
    runG2STest(4, false, true, 2, {16}, {32}); // 16-byte aligned fields
}
TEST_F(FusedMoeCommG2STest, MixedAlignmentTypes)
{
    runG2STest(3, true, true, 4, {8, 4, 2, 1}, {64, 32, 16, 8}); // All alignment types
    runG2STest(3, true, true, 4, {1, 2, 4, 8}, {63, 30, 17, 9}, 32); // All alignment types, odd vector sizes, 32 tokens
}
TEST_F(FusedMoeCommG2STest, SingleByteAlignment)
{
    runG2STest(2, false, true, 2, {1}, {128}); // Single byte alignment
}
TEST_F(FusedMoeCommG2STest, EdgeCaseTopKOne)
{
    runG2STest(1, false, true, 1, {4}, {16}); // Minimal topK
}
TEST_F(FusedMoeCommG2STest, EdgeCaseNoExtraFields)
{
    runG2STest(2, true, true, 0, {}, {}); // Only basic fields (token slots + scales)
}
TEST_F(FusedMoeCommG2STest, LargeTokenCount)
{
    runG2STest(4, true, true, 2, {4, 8}, {64, 128}, 16, 4); // 16 tokens, 4 warps per block
}
// Tests for the no-basic-fields scenario (hasBasicFields=false, topK unused)
TEST_F(FusedMoeCommG2STest, G2SWithoutBasicFields)
{
    runG2STest(0, false, false, 2, {4, 8}, {32, 64}); // No basic fields, only field data
}
TEST_F(FusedMoeCommG2STest, G2SWithoutBasicFieldsLargeFields)
{
    runG2STest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}); // No basic fields, large field data
}
// Tests for S2G (shared-to-global writeback) functionality.
// Argument order: topK, hasScales, hasBasicFields, fieldCount, elementSizes,
// vectorSizes[, tokenCount, warpsPerBlock].
TEST_F(FusedMoeCommS2GTest, BasicS2GWithoutScales)
{
    runS2GTest(2, false, true, 1, {4}, {64}); // topK=2, no scales, has basic fields, 1 field, 4-byte elements
}
TEST_F(FusedMoeCommS2GTest, BasicS2GWithScales)
{
    runS2GTest(4, true, true, 1, {4}, {32}); // topK=4, with scales, has basic fields, 1 field, 4-byte elements
}
TEST_F(FusedMoeCommS2GTest, MultipleFieldsVariousAlignments)
{
    runS2GTest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}); // Multiple fields with different element sizes
}
TEST_F(FusedMoeCommS2GTest, LargeTopK)
{
    runS2GTest(8, true, true, 2, {4, 8}, {128, 256}); // Large topK value
}
TEST_F(FusedMoeCommS2GTest, PerfectAlignmentFields)
{
    runS2GTest(4, false, true, 2, {16}, {32}); // 16-byte aligned fields
}
TEST_F(FusedMoeCommS2GTest, MixedAlignmentTypes)
{
    runS2GTest(3, true, true, 4, {1, 2, 4, 8}, {8, 16, 32, 64}); // All alignment types
    runS2GTest(3, true, true, 4, {1, 2, 4, 8}, {63, 30, 17, 9}, 32); // All alignment types, odd vector sizes, 32 tokens
}
TEST_F(FusedMoeCommS2GTest, SingleByteAlignment)
{
    runS2GTest(2, false, true, 2, {1}, {128}); // Single byte alignment
}
TEST_F(FusedMoeCommS2GTest, EdgeCaseTopKOne)
{
    runS2GTest(1, false, true, 1, {4}, {16}); // Minimal topK
}
TEST_F(FusedMoeCommS2GTest, EdgeCaseNoExtraFields)
{
    runS2GTest(2, true, true, 0, {}, {}); // Only basic fields (token slots + scales)
}
TEST_F(FusedMoeCommS2GTest, LargeTokenCount)
{
    runS2GTest(4, true, true, 2, {4, 8}, {64, 128}, 16, 4); // 16 tokens, 4 warps per block
}
// Tests for the no-basic-fields scenario (hasBasicFields=false, topK unused)
TEST_F(FusedMoeCommS2GTest, S2GWithoutBasicFields)
{
    runS2GTest(0, false, false, 2, {4, 8}, {32, 64}); // No basic fields, only field data
}
TEST_F(FusedMoeCommS2GTest, S2GWithoutBasicFieldsLargeFields)
{
    runS2GTest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}); // No basic fields, large field data
}
// Tests for G2S+Pack+Unpack+S2G loopback functionality.
// `mapping` is the send->recv index permutation; it is extended/repaired to a
// full one-to-one mapping of tokenCount inside runLoopbackTest.
TEST_F(FusedMoeCommLoopbackTest, BasicLoopbackWithoutScales)
{
    std::vector<int> mapping = {0, 1, 2, 3}; // Identity mapping
    runLoopbackTest(2, false, true, 1, {4}, {64}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, BasicLoopbackWithScales)
{
    std::vector<int> mapping = {0, 1, 2, 3}; // Identity mapping
    runLoopbackTest(4, true, true, 1, {4}, {32}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackWithReordering)
{
    std::vector<int> mapping = {3, 0, 2, 1}; // Reorder mapping
    runLoopbackTest(2, true, true, 2, {4, 8}, {32, 64}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackWithReverseMapping)
{
    std::vector<int> mapping = {3, 2, 1, 0}; // Reverse mapping
    runLoopbackTest(3, false, true, 1, {2}, {128}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackMultipleFieldsVariousAlignments)
{
    std::vector<int> mapping = {1, 3, 0, 2}; // Complex reordering
    runLoopbackTest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackLargeTopK)
{
    std::vector<int> mapping = {2, 0, 3, 1}; // Reorder mapping
    runLoopbackTest(8, true, true, 2, {4, 8}, {128, 256}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackPerfectAlignmentFields)
{
    std::vector<int> mapping = {0, 2, 1, 3}; // Partial reordering
    runLoopbackTest(4, false, true, 2, {16}, {32}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackMixedAlignmentTypes)
{
    std::vector<int> mapping = {1, 0, 3, 2}; // Pair swap
    runLoopbackTest(3, true, true, 4, {1, 2, 4, 8}, {8, 16, 32, 64}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackSingleByteAlignment)
{
    std::vector<int> mapping = {2, 3, 0, 1}; // Cyclic shift
    runLoopbackTest(2, false, true, 2, {1}, {128}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackEdgeCaseTopKOne)
{
    std::vector<int> mapping = {1, 0, 3, 2}; // Simple reordering
    runLoopbackTest(1, false, true, 1, {4}, {16}, mapping);
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackEdgeCaseNoExtraFields)
{
    std::vector<int> mapping = {3, 1, 0, 2}; // Random reordering
    runLoopbackTest(2, true, true, 0, {}, {}, mapping); // Only basic fields
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackLargeTokenCount)
{
    std::vector<int> mapping = {7, 0, 5, 2, 3, 6, 1, 4, 15, 8, 11, 10, 9, 14, 13, 12}; // Complex 16-token mapping
    runLoopbackTest(4, true, true, 2, {4, 8}, {64, 128}, mapping, 16, 4);
}
// Tests for the no-basic-fields scenario (hasBasicFields=false, topK unused)
TEST_F(FusedMoeCommLoopbackTest, LoopbackWithoutBasicFields)
{
    std::vector<int> mapping = {1, 3, 0, 2}; // Reorder mapping
    runLoopbackTest(0, false, false, 2, {4, 8}, {32, 64}, mapping); // No basic fields, only field data
}
TEST_F(FusedMoeCommLoopbackTest, LoopbackWithoutBasicFieldsLargeFields)
{
    std::vector<int> mapping = {2, 0, 3, 1}; // Reorder mapping
    runLoopbackTest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}, mapping); // No basic fields, large field data
}
// Test class for launchLocalFifoSendRecv function (FIFO-based local send/recv test)
class FusedMoeCommLocalFifoSendRecvTest : public FusedMoeCommTestBase
{
protected:
    // Runs one single-rank (epRank 0 of epSize 1) FIFO send/recv pass and verifies
    // that every payload arrives intact under independent send/recv index mappings.
    //
    // topK / hasScales / hasBasicFields : configure the "basic" per-token payload
    //     (token-selected expert slots and, optionally, expert scales).
    // fieldCount + elementSizes / vectorSizes : describe the extra per-token fields;
    //     both size vectors are cycled with modulo when fieldCount is larger.
    // sendIndexMappingVec / recvIndexMappingVec : pattern vectors expanded by
    //     generateOneToOneMapping() into tokenCount-sized one-to-one mappings.
    // tokenCount / warpsPerBlock / blockChannelCount : problem size and launch
    //     shape; warpsPerBlock * blockChannelCount FIFO channels are allocated.
    void runLocalFifoSendRecvTest(int topK, bool hasScales, bool hasBasicFields, int fieldCount,
        std::vector<size_t> const& elementSizes, std::vector<uint16_t> const& vectorSizes,
        std::vector<int> const& sendIndexMappingVec, std::vector<int> const& recvIndexMappingVec, int tokenCount = 4,
        int warpsPerBlock = 2, int blockChannelCount = 1)
    {
        // Setup expert parallel info
        MoeExpertParallelInfo expertParallelInfo;
        expertParallelInfo.topK = topK;
        expertParallelInfo.expertCount = 8; // slot values below are generated modulo this count

        // Setup field info for send and receive sides
        FusedMoeFieldInfo sendFieldInfo = {};
        sendFieldInfo.isBasicInterleaved = false;
        sendFieldInfo.fieldCount = fieldCount;
        FusedMoeFieldInfo recvFieldInfo = {};
        recvFieldInfo.isBasicInterleaved = false;
        recvFieldInfo.fieldCount = fieldCount;

        // Allocate token selected slots and expert scales if needed.
        // Host copies keep the expected values; device recv buffers start zeroed
        // so any byte left untouched by the kernel is detected during verify.
        int* hostSendTokenSlots = nullptr;
        int* deviceSendTokenSlots = nullptr;
        float* hostSendScales = nullptr;
        float* deviceSendScales = nullptr;
        int* hostRecvTokenSlots = nullptr;
        int* deviceRecvTokenSlots = nullptr;
        float* hostRecvScales = nullptr;
        float* deviceRecvScales = nullptr;
        if (hasBasicFields)
        {
            // Send side basic fields
            allocateAndInitializeData<int>(&hostSendTokenSlots, &deviceSendTokenSlots, tokenCount * topK,
                [](size_t i) { return static_cast<int>(i % 8); });
            sendFieldInfo.tokenSelectedSlots = deviceSendTokenSlots;
            // Recv side basic fields (initialized to zero, will be filled by communication)
            allocateAndInitializeData<int>(
                &hostRecvTokenSlots, &deviceRecvTokenSlots, tokenCount * topK, [](size_t) { return 0; });
            recvFieldInfo.tokenSelectedSlots = deviceRecvTokenSlots;
            if (hasScales)
            {
                allocateAndInitializeData<float>(&hostSendScales, &deviceSendScales, tokenCount * topK,
                    [](size_t i) -> float { return 1.0f + static_cast<float>(i) * 0.1f; });
                sendFieldInfo.expertScales = deviceSendScales;
                allocateAndInitializeData<float>(
                    &hostRecvScales, &deviceRecvScales, tokenCount * topK, [](size_t) { return 0.0f; });
                recvFieldInfo.expertScales = deviceRecvScales;
            }
        }

        // Setup field info for additional fields
        std::vector<void*> hostSendFieldPtrs(fieldCount);
        std::vector<void*> deviceSendFieldPtrs(fieldCount);
        std::vector<void*> hostRecvFieldPtrs(fieldCount);
        std::vector<void*> deviceRecvFieldPtrs(fieldCount);
        for (int i = 0; i < fieldCount; i++)
        {
            // Size vectors are cycled so fieldCount may exceed their length.
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            // Allocate send field data with specific pattern (distinct per field index)
            uint8_t* hostSendField;
            uint8_t* deviceSendField;
            allocateAndInitializeData<uint8_t>(&hostSendField, &deviceSendField, fieldSize,
                [i](size_t idx) { return static_cast<uint8_t>((i * 100 + idx + 1) % 128); });
            // Allocate recv field data (initially zero, will be filled by communication)
            uint8_t* hostRecvField;
            uint8_t* deviceRecvField;
            allocateAndInitializeData<uint8_t>(
                &hostRecvField, &deviceRecvField, fieldSize, [](size_t) { return static_cast<uint8_t>(0); });
            hostSendFieldPtrs[i] = hostSendField;
            deviceSendFieldPtrs[i] = deviceSendField;
            hostRecvFieldPtrs[i] = hostRecvField;
            deviceRecvFieldPtrs[i] = deviceRecvField;
            // Fill field info (stride == vectorSize, i.e. fields are densely packed)
            sendFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceSendField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
            recvFieldInfo.fieldsInfo[i].fillFieldInfo(
                deviceRecvField, elementSize, vectorSize, vectorSize, getCudaDataType(elementSize));
        }

        // Fill field placement info
        sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);
        recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields);

        // Setup sendIndexMapping and recvIndexMapping - ensure one-to-one mappings
        std::vector<int> fullSendMapping = generateOneToOneMapping(sendIndexMappingVec, tokenCount);
        std::vector<int> fullRecvMapping = generateOneToOneMapping(recvIndexMappingVec, tokenCount);
        int* hostSendIndexMapping;
        int* deviceSendIndexMapping;
        int* hostRecvIndexMapping;
        int* deviceRecvIndexMapping;
        allocateAndInitializeData<int>(&hostSendIndexMapping, &deviceSendIndexMapping, tokenCount,
            [&fullSendMapping](size_t i) { return fullSendMapping[i]; });
        allocateAndInitializeData<int>(&hostRecvIndexMapping, &deviceRecvIndexMapping, tokenCount,
            [&fullRecvMapping](size_t i) { return fullRecvMapping[i]; });

        // Setup workspace for FIFO communication (single rank, so the per-rank
        // size is the total size)
        FusedMoeWorkspace fusedMoeWorkspace;
        int totalChannelCount = blockChannelCount * warpsPerBlock;
        size_t workspaceSizePerRank = FusedMoeWorkspace::computeWorkspaceSizePreRank(1, totalChannelCount);
        size_t totalWorkspaceSize = workspaceSizePerRank;
        fusedMoeWorkspace.rankStrideInU64 = workspaceSizePerRank / sizeof(uint64_t);
        fusedMoeWorkspace.channelCount = totalChannelCount;
        TLLM_CUDA_CHECK(cudaMalloc(&fusedMoeWorkspace.workspacePtr, totalWorkspaceSize));

        // Initialize workspace for a single-rank EP world (rank 0 of 1)
        FusedMoeWorldInfo worldInfo;
        worldInfo.epInfo.epRank = 0;
        worldInfo.epInfo.epSize = 1;
        fusedMoeWorkspace.initializeLocalWorkspace(worldInfo);

        // Launch FIFO send/recv kernel
        fused_moe_comm_tests::launchLocalFifoSendRecv(sendFieldInfo, recvFieldInfo, expertParallelInfo,
            deviceSendIndexMapping, deviceRecvIndexMapping, fusedMoeWorkspace, tokenCount, warpsPerBlock,
            blockChannelCount, hasBasicFields, stream);
        TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));

        // Copy back results and verify
        verifyLocalFifoSendRecvResults(hostSendTokenSlots, hostSendScales, hostSendFieldPtrs, hostRecvFieldPtrs,
            deviceRecvTokenSlots, deviceRecvScales, deviceRecvFieldPtrs, fullSendMapping, fullRecvMapping, topK,
            hasScales, hasBasicFields, fieldCount, elementSizes, vectorSizes, tokenCount);

        // Cleanup
        // NOTE(review): if a TLLM_CUDA_CHECK above throws, these frees are skipped;
        // acceptable in a test, but RAII wrappers would make this exception-safe.
        if (hasBasicFields)
        {
            cleanup(hostSendTokenSlots, deviceSendTokenSlots);
            cleanup(hostRecvTokenSlots, deviceRecvTokenSlots);
            if (hasScales)
            {
                cleanup(hostSendScales, deviceSendScales);
                cleanup(hostRecvScales, deviceRecvScales);
            }
        }
        for (int i = 0; i < fieldCount; i++)
        {
            cleanup(hostSendFieldPtrs[i], deviceSendFieldPtrs[i]);
            cleanup(hostRecvFieldPtrs[i], deviceRecvFieldPtrs[i]);
        }
        cleanup(hostSendIndexMapping, deviceSendIndexMapping);
        cleanup(hostRecvIndexMapping, deviceRecvIndexMapping);
        TLLM_CUDA_CHECK(cudaFree(fusedMoeWorkspace.workspacePtr));
    }

private:
    // Copies the receive-side device buffers back to the host and checks them
    // against the send-side host copies. For each logical token i the kernel is
    // expected to have produced
    //     recv[fullRecvMapping[i]] == send[fullSendMapping[i]]
    // for the basic fields (slots and, if enabled, scales) and all extra fields.
    // At most 16 EXPECT failures are emitted per category; higher counts are
    // summarized with a single ADD_FAILURE at the end.
    // NOTE(review): hostRecvFields is never read here - candidate for removal.
    void verifyLocalFifoSendRecvResults(int const* expectedSendTokenSlots, float const* expectedSendScales,
        std::vector<void*> const& expectedSendFields, std::vector<void*> const& hostRecvFields,
        int* deviceRecvTokenSlots, float* deviceRecvScales, std::vector<void*> const& deviceRecvFields,
        std::vector<int> const& fullSendMapping, std::vector<int> const& fullRecvMapping, int topK, bool hasScales,
        bool hasBasicFields, int fieldCount, std::vector<size_t> const& elementSizes,
        std::vector<uint16_t> const& vectorSizes, int tokenCount)
    {
        // Copy back device results for verification
        int* resultRecvTokenSlots = nullptr;
        float* resultRecvScales = nullptr;
        if (hasBasicFields)
        {
            resultRecvTokenSlots = new int[tokenCount * topK];
            TLLM_CUDA_CHECK(cudaMemcpy(
                resultRecvTokenSlots, deviceRecvTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost));
            if (hasScales)
            {
                resultRecvScales = new float[tokenCount * topK];
                TLLM_CUDA_CHECK(cudaMemcpy(
                    resultRecvScales, deviceRecvScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost));
            }
        }
        // Copy back field data (sizes recomputed with the same modulo cycling
        // used at allocation time)
        std::vector<uint8_t*> resultRecvFields(fieldCount);
        for (int i = 0; i < fieldCount; i++)
        {
            size_t elementSize = elementSizes[i % elementSizes.size()];
            uint16_t vectorSize = vectorSizes[i % vectorSizes.size()];
            size_t fieldSize = elementSize * vectorSize * tokenCount;
            resultRecvFields[i] = new uint8_t[fieldSize];
            TLLM_CUDA_CHECK(cudaMemcpy(resultRecvFields[i], deviceRecvFields[i], fieldSize, cudaMemcpyDeviceToHost));
        }
        // Verify the FIFO send/recv with independent mappings:
        // For logical index i:
        //   - Send side reads from fullSendMapping[i]
        //   - Recv side writes to fullRecvMapping[i]
        // So we need to verify: recv[fullRecvMapping[i]] should equal send[fullSendMapping[i]]
        int tokenSlotErrorCount = 0;
        int scaleErrorCount = 0;
        std::vector<int> fieldErrorCounts(fieldCount, 0);
        for (int logicalIndex = 0; logicalIndex < tokenCount; logicalIndex++)
        {
            int actualSendIndex = fullSendMapping[logicalIndex];
            int actualRecvIndex = fullRecvMapping[logicalIndex];
            // Skip out-of-range mapping entries rather than reading out of bounds
            if (actualSendIndex < 0 || actualSendIndex >= tokenCount || actualRecvIndex < 0
                || actualRecvIndex >= tokenCount)
                continue;
            // Verify token selected slots
            if (hasBasicFields)
            {
                for (int k = 0; k < topK; k++)
                {
                    int expectedSlot = expectedSendTokenSlots[actualSendIndex * topK + k];
                    int actualSlot = resultRecvTokenSlots[actualRecvIndex * topK + k];
                    if (expectedSlot != actualSlot)
                    {
                        tokenSlotErrorCount++;
                        // Cap detailed output at 16 mismatches per category
                        if (tokenSlotErrorCount <= 16)
                        {
                            EXPECT_EQ(expectedSlot, actualSlot)
                                << "Token slot mismatch at logicalIndex=" << logicalIndex
                                << ", actualSendIndex=" << actualSendIndex << ", actualRecvIndex=" << actualRecvIndex
                                << ", k=" << k;
                        }
                    }
                }
                // Verify expert scales
                if (hasScales)
                {
                    for (int k = 0; k < topK; k++)
                    {
                        float expectedScale = expectedSendScales[actualSendIndex * topK + k];
                        float actualScale = resultRecvScales[actualRecvIndex * topK + k];
                        // Scales are copied verbatim, so a tight float tolerance suffices
                        if (std::abs(expectedScale - actualScale) > 1e-6f)
                        {
                            scaleErrorCount++;
                            if (scaleErrorCount <= 16)
                            {
                                EXPECT_NEAR(expectedScale, actualScale, 1e-6f)
                                    << "Scale mismatch at logicalIndex=" << logicalIndex
                                    << ", actualSendIndex=" << actualSendIndex
                                    << ", actualRecvIndex=" << actualRecvIndex << ", k=" << k;
                            }
                        }
                    }
                }
            }
            // Verify additional fields (byte-wise comparison per token)
            for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++)
            {
                size_t elementSize = elementSizes[fieldIdx % elementSizes.size()];
                uint16_t vectorSize = vectorSizes[fieldIdx % vectorSizes.size()];
                size_t fieldSizePerToken = elementSize * vectorSize;
                uint8_t const* expectedFieldData = static_cast<uint8_t const*>(expectedSendFields[fieldIdx]);
                uint8_t const* actualFieldData = resultRecvFields[fieldIdx];
                for (size_t byteIdx = 0; byteIdx < fieldSizePerToken; byteIdx++)
                {
                    uint8_t expected = expectedFieldData[actualSendIndex * fieldSizePerToken + byteIdx];
                    uint8_t actual = actualFieldData[actualRecvIndex * fieldSizePerToken + byteIdx];
                    if (expected != actual)
                    {
                        fieldErrorCounts[fieldIdx]++;
                        if (fieldErrorCounts[fieldIdx] <= 16)
                        {
                            EXPECT_EQ(static_cast<int>(expected), static_cast<int>(actual))
                                << "Field[" << fieldIdx << "] mismatch at logicalIndex=" << logicalIndex
                                << ", actualSendIndex=" << actualSendIndex << ", actualRecvIndex=" << actualRecvIndex
                                << ", byteIdx=" << byteIdx;
                        }
                    }
                }
            }
        }
        // Print error summary for counts exceeding 16
        if (tokenSlotErrorCount > 16)
        {
            ADD_FAILURE() << "Token slot errors: Showed first 16 of " << tokenSlotErrorCount << " total mismatches.";
        }
        if (scaleErrorCount > 16)
        {
            ADD_FAILURE() << "Scale errors: Showed first 16 of " << scaleErrorCount << " total mismatches.";
        }
        for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++)
        {
            if (fieldErrorCounts[fieldIdx] > 16)
            {
                ADD_FAILURE() << "Field[" << fieldIdx << "] errors: Showed first 16 of " << fieldErrorCounts[fieldIdx]
                              << " total mismatches.";
            }
        }
        // Cleanup temporary arrays (delete[] on nullptr is a no-op, guards are belt-and-braces)
        if (resultRecvTokenSlots)
            delete[] resultRecvTokenSlots;
        if (resultRecvScales)
            delete[] resultRecvScales;
        for (int i = 0; i < fieldCount; i++)
        {
            if (resultRecvFields[i])
                delete[] resultRecvFields[i];
        }
    }
};
// Tests for Local FIFO Send/Recv functionality with Packed Protocol
TEST_F(FusedMoeCommLocalFifoSendRecvTest, BasicFifoSendRecvPackedProtocol)
{
    // Send side keeps token order; recv side rotates tokens by two positions.
    std::vector<int> const sendOrder{0, 1, 2, 3};
    std::vector<int> const recvOrder{2, 3, 0, 1};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/1, {4}, {64},
        sendOrder, recvOrder, /*tokenCount=*/4, /*warpsPerBlock=*/1, /*blockChannelCount=*/1);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, BasicFifoSendRecvWithScalesPackedProtocol)
{
    // Opposite single-step rotations on the two sides; expert scales enabled.
    std::vector<int> const sendOrder{1, 2, 3, 0};
    std::vector<int> const recvOrder{3, 0, 1, 2};
    runLocalFifoSendRecvTest(/*topK=*/4, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/1, {4}, {32},
        sendOrder, recvOrder, /*tokenCount=*/4, /*warpsPerBlock=*/2, /*blockChannelCount=*/1);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvWithReorderingPackedProtocol)
{
    // Independent shuffles on the send and recv sides with two extra fields.
    std::vector<int> const sendOrder{3, 0, 2, 1};
    std::vector<int> const recvOrder{0, 3, 1, 2};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/2, {4, 8},
        {32, 64}, sendOrder, recvOrder, /*tokenCount=*/256, /*warpsPerBlock=*/2, /*blockChannelCount=*/2);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvMultipleFieldsPackedProtocol)
{
    // Same permutation on both sides; three fields of mixed element sizes.
    std::vector<int> const order{1, 3, 0, 2};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/3, {1, 2, 4},
        {16, 32, 64}, order, order, /*tokenCount=*/256, /*warpsPerBlock=*/2, /*blockChannelCount=*/2);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvLargeTopKPackedProtocol)
{
    // topK of 8 with wide vectors to stress basic-field packing.
    std::vector<int> const order{2, 0, 3, 1};
    runLocalFifoSendRecvTest(/*topK=*/8, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/2, {4, 8},
        {128, 256}, order, order, /*tokenCount=*/512, /*warpsPerBlock=*/3, /*blockChannelCount=*/2);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvWithoutBasicFieldsPackedProtocol)
{
    // Only extra field payloads are exchanged; recv side reverses token order.
    std::vector<int> const sendOrder{1, 3, 0, 2};
    std::vector<int> const recvOrder{3, 2, 1, 0};
    runLocalFifoSendRecvTest(/*topK=*/0, /*hasScales=*/false, /*hasBasicFields=*/false, /*fieldCount=*/2, {4, 8},
        {32, 64}, sendOrder, recvOrder, /*tokenCount=*/256, /*warpsPerBlock=*/2, /*blockChannelCount=*/2);
}
// Mixed alignment tests
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvMixedAlignmentsPackedProtocol)
{
    // Four fields covering 1/2/4/8-byte element sizes; adjacent pairs swapped.
    std::vector<int> const order{1, 0, 3, 2};
    runLocalFifoSendRecvTest(/*topK=*/3, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/4, {1, 2, 4, 8},
        {8, 16, 32, 64}, order, order, /*tokenCount=*/512, /*warpsPerBlock=*/2, /*blockChannelCount=*/2);
}
// Edge cases
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvEdgeCaseTopKOnePackedProtocol)
{
    // Smallest non-trivial topK (1) with a single narrow field.
    std::vector<int> const order{1, 0, 3, 2};
    runLocalFifoSendRecvTest(/*topK=*/1, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/1, {4}, {16},
        order, order, /*tokenCount=*/128, /*warpsPerBlock=*/2, /*blockChannelCount=*/1);
}
// Only basic fields cases
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvEdgeCaseNoExtraFieldsPackedProtocol)
{
    // No extra fields at all: exercises the basic-fields-only path.
    std::vector<int> const order{3, 1, 0, 2};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/0, {}, {}, order,
        order, /*tokenCount=*/256, /*warpsPerBlock=*/2, /*blockChannelCount=*/2);
}
// Large scale tests
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvLargeTokenCountPackedProtocol)
{
    // 1024 tokens: scrambled send pattern against a fully reversed recv pattern.
    std::vector<int> const sendOrder{7, 0, 5, 2, 3, 6, 1, 4, 15, 8, 11, 10, 9, 14, 13, 12};
    std::vector<int> const recvOrder{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
    runLocalFifoSendRecvTest(/*topK=*/4, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/2, {4, 8},
        {64, 128}, sendOrder, recvOrder, /*tokenCount=*/1024, /*warpsPerBlock=*/3, /*blockChannelCount=*/3);
}
// Perfect alignment tests
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvPerfectAlignmentPackedProtocol)
{
    // 16-byte elements keep every field access aligned; distinct side shuffles.
    std::vector<int> const sendOrder{2, 0, 3, 1};
    std::vector<int> const recvOrder{1, 3, 0, 2};
    runLocalFifoSendRecvTest(/*topK=*/4, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/2, {16}, {32},
        sendOrder, recvOrder, /*tokenCount=*/256, /*warpsPerBlock=*/2, /*blockChannelCount=*/3);
}
// Single byte alignment tests
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvSmallSingleByteAlignmentPackedProtocol)
{
    // Tiny token count with an odd-length (127-byte) single-byte field.
    std::vector<int> const order{2, 3, 0, 1};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/1, {1}, {127},
        order, order, /*tokenCount=*/4, /*warpsPerBlock=*/1, /*blockChannelCount=*/1);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvSingleByteAlignmentPackedProtocol)
{
    // Larger run of the odd-length single-byte payload, two fields this time.
    std::vector<int> const order{2, 3, 0, 1};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/2, {1}, {127},
        order, order, /*tokenCount=*/256, /*warpsPerBlock=*/3, /*blockChannelCount=*/1);
}
// Stress tests
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTestManyChannelsPackedProtocol)
{
    // Four channels per block to stress channel scheduling.
    std::vector<int> const order{7, 2, 5, 0, 3, 6, 1, 4};
    runLocalFifoSendRecvTest(/*topK=*/4, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/2, {8, 16},
        {128, 256}, order, order, /*tokenCount=*/512, /*warpsPerBlock=*/3, /*blockChannelCount=*/4);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTest2ManyChannelsPackedProtocol)
{
    // Odd vector sizes and a large token count with one warp per block.
    std::vector<int> const order{7, 2, 5, 0, 3, 6, 1, 4};
    runLocalFifoSendRecvTest(/*topK=*/4, /*hasScales=*/true, /*hasBasicFields=*/true, /*fieldCount=*/2, {2, 4, 8, 16},
        {7, 15, 31, 255}, order, order, /*tokenCount=*/4096, /*warpsPerBlock=*/1, /*blockChannelCount=*/2);
}
TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTestManyWarpsPackedProtocol)
{
    // Four warps per block to stress intra-block parallelism.
    std::vector<int> const order{1, 0, 3, 2};
    runLocalFifoSendRecvTest(/*topK=*/2, /*hasScales=*/false, /*hasBasicFields=*/true, /*fieldCount=*/1, {4}, {64},
        order, order, /*tokenCount=*/256, /*warpsPerBlock=*/4, /*blockChannelCount=*/2);
}