/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/ncclUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

// NOTE(review): the four angle-bracket include targets were missing from the reviewed text;
// the headers below are reconstructed from what this file uses - confirm against the original.
#include <gtest/gtest.h>
#include <memory>
#include <set>
#include <vector>

#if ENABLE_MULTI_DEVICE && BUILD_PYT
// NOTE(review): include target reconstructed (torch tensor / dtype symbols are used below) - confirm.
#include <torch/torch.h>
#endif

#if ENABLE_MULTI_DEVICE

namespace mpi = tensorrt_llm::mpi;
namespace tr = tensorrt_llm::runtime;
namespace nccl_util = tensorrt_llm::common::nccl_util;

using tensorrt_llm::getComm;

// NOTE(review): template arguments throughout this file (std::shared_ptr, std::set, std::vector,
// reinterpret_cast) were missing from the reviewed text and are reconstructed from usage - confirm.

// Helper function to create a split communicator for testing.
// This allows us to test cleanup behavior explicitly by controlling the lifetime.
//
// parentComm, color and key are forwarded unchanged to ncclCommSplit.
// Returns a shared_ptr whose deleter first cleans up the resources registered with
// NcclCommResourceManager for the communicator and then destroys the communicator.
// Raises via TLLM_THROW when ncclCommSplit does not return ncclSuccess.
std::shared_ptr<ncclComm_t> createSplitComm(ncclComm_t parentComm, int color, int key)
{
    ncclComm_t newComm = nullptr;
    ncclResult_t result = ncclCommSplit(parentComm, color, key, &newComm, nullptr);
    if (result != ncclSuccess)
    {
        TLLM_THROW("ncclCommSplit failed with error: %d", result);
    }

    // Create a shared_ptr with custom deleter that cleans up resources first
    return std::shared_ptr<ncclComm_t>(new ncclComm_t(newComm),
        [](ncclComm_t* comm)
        {
            // STEP 3 (free the memory) is handled by this owner on every exit path, so the
            // heap-allocated handle is released even when the stored communicator is null.
            std::unique_ptr<ncclComm_t> handle(comm);
            if (!handle || !*handle)
            {
                return;
            }

            // STEP 1: Clean up all registered resources FIRST
            tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*handle);

            // STEP 2: Now destroy the NCCL communicator
            ncclResult_t result = ncclCommDestroy(*handle);
            if (result != ncclSuccess)
            {
                TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result);
            }
        });
}

//==============================================================================
// NcclCommResourceManager Tests
//==============================================================================

// Fixture that skips when fewer than 2 MPI ranks are present, selects a CUDA device for the
// rank and obtains a communicator spanning all ranks through getComm().
class NcclCommResourceManagerTest : public ::testing::Test
{
protected:
    void SetUp() override
    {
        auto& comm = mpi::MpiComm::world();
        mWorldSize = comm.getSize();
        mRank = comm.getRank();

        if (mWorldSize < 2)
        {
            GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")";
        }

        // Set CUDA device for this rank (required before NCCL initialization)
        int deviceCount = 0;
        TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
        if (deviceCount > 0)
        {
            int deviceId = mRank % deviceCount;
            TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
        }

        // Create a communicator for testing
        std::set<int> group;
        for (int i = 0; i < mWorldSize; ++i)
        {
            group.insert(i);
        }
        mComm = getComm(group);
    }

    void TearDown() override
    {
        // Communicator cleanup happens automatically via shared_ptr deleter
        mComm.reset();
    }

    // Initialized so the members hold defined values even if SetUp() exits early.
    int mWorldSize{0};
    int mRank{0};
    std::shared_ptr<ncclComm_t> mComm;
};

TEST_F(NcclCommResourceManagerTest, ResourceRegistration)
{
    auto& manager = nccl_util::NcclCommResourceManager::getInstance();

    // Create a separate comm using split for this test
    auto testComm = createSplitComm(*mComm, 0, mRank);

    // Register a resource
    bool cleanupCalled = false;
    manager.registerResource(
        *testComm, [&cleanupCalled]() { cleanupCalled = true; }, "TestResource");

    EXPECT_TRUE(manager.hasResources(*testComm));
    EXPECT_EQ(manager.getResourceCount(*testComm), 1);
    EXPECT_FALSE(cleanupCalled); // Cleanup not called yet

    // Store the raw comm value before destruction
    ncclComm_t rawComm = *testComm;

    // Cleanup should be called when comm is destroyed
    testComm.reset();

    // Verify cleanup was called
    EXPECT_TRUE(cleanupCalled);

    // Verify cleanup: check that the old comm (now destroyed) no longer has resources
    // Note: The comm is destroyed, but we can still check the manager's internal state
    // The cleanup should have removed all resources for this comm
    EXPECT_FALSE(manager.hasResources(rawComm));
    EXPECT_EQ(manager.getResourceCount(rawComm), 0);
}

TEST_F(NcclCommResourceManagerTest, MultipleResources)
{
    auto& manager = nccl_util::NcclCommResourceManager::getInstance();

    // Create a separate comm using split for this test
    auto testComm = createSplitComm(*mComm, 0, mRank);

    std::vector<int> cleanupOrder;
    manager.registerResource(
        *testComm, [&cleanupOrder]() { cleanupOrder.push_back(1); }, "Resource1");
    manager.registerResource(
        *testComm, [&cleanupOrder]() { cleanupOrder.push_back(2); }, "Resource2");
    manager.registerResource(
        *testComm, [&cleanupOrder]() { cleanupOrder.push_back(3); }, "Resource3");

    EXPECT_EQ(manager.getResourceCount(*testComm), 3);

    // Cleanup order should be preserved - destroy comm and verify order
    testComm.reset();

    // Verify cleanup order was preserved (1, 2, 3).
    // Fatal assertion: the element accesses below would read out of bounds otherwise.
    ASSERT_EQ(cleanupOrder.size(), 3u);
    EXPECT_EQ(cleanupOrder[0], 1);
    EXPECT_EQ(cleanupOrder[1], 2);
    EXPECT_EQ(cleanupOrder[2], 3);
}

TEST_F(NcclCommResourceManagerTest, ResourceCount)
{
    auto& manager = nccl_util::NcclCommResourceManager::getInstance();

    // Create a separate comm using split for this test
    auto testComm = createSplitComm(*mComm, 0, mRank);

    EXPECT_FALSE(manager.hasResources(*testComm));
    EXPECT_EQ(manager.getResourceCount(*testComm), 0);

    manager.registerResource(
        *testComm, []() {}, "Test1");
    EXPECT_EQ(manager.getResourceCount(*testComm), 1);

    manager.registerResource(
        *testComm, []() {}, "Test2");
    EXPECT_EQ(manager.getResourceCount(*testComm), 2);

    testComm.reset();
}

//==============================================================================
// NCCLWindowAllocator Tests
//==============================================================================

// Fixture with the same rank/device setup as above; additionally skips when
// NCCLHelper reports that the NCCL library is not loaded.
class NCCLWindowAllocatorTest : public ::testing::Test
{
protected:
    void SetUp() override
    {
        auto& comm = mpi::MpiComm::world();
        mWorldSize = comm.getSize();
        mRank = comm.getRank();

        if (mWorldSize < 2)
        {
            GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")";
        }

        // Set CUDA device for this rank (required before NCCL initialization)
        int deviceCount = 0;
        TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
        if (deviceCount > 0)
        {
            int deviceId = mRank % deviceCount;
            TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
        }

        // Check if NCCL symmetric is supported
        auto& ncclHelper = nccl_util::NCCLHelper::getInstance();
        if (!ncclHelper.isLoaded())
        {
            GTEST_SKIP() << "NCCL library with symmetric memory support is not available";
        }

        std::set<int> group;
        for (int i = 0; i < mWorldSize; ++i)
        {
            group.insert(i);
        }
        mComm = getComm(group);
    }

    void TearDown() override
    {
        // Cleanup happens automatically
        mComm.reset();
    }

    // Initialized so the members hold defined values even if SetUp() exits early.
    int mWorldSize{0};
    int mRank{0};
    std::shared_ptr<ncclComm_t> mComm;
};

TEST_F(NCCLWindowAllocatorTest, BasicAllocation)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 1024 * 1024; // 1MB
    auto buffer = allocator.requestBuffer(*mComm, bufferSize);

    EXPECT_TRUE(buffer.isValid());
    EXPECT_NE(buffer.ptr, nullptr);
    EXPECT_NE(buffer.window, nullptr);
    EXPECT_EQ(buffer.size, bufferSize);
    EXPECT_GE(buffer.handle, 0);

    // Verify we can search for it
    auto found = allocator.searchBuffer(*mComm, buffer.ptr);
    EXPECT_TRUE(found.isValid());
    EXPECT_EQ(found.ptr, buffer.ptr);

    // Release the buffer
    allocator.releaseBuffer(*mComm, buffer.ptr);
}

TEST_F(NCCLWindowAllocatorTest, BufferReuse)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 512 * 1024; // 512KB

    // Allocate first buffer
    auto buffer1 = allocator.requestBuffer(*mComm, bufferSize);
    EXPECT_TRUE(buffer1.isValid());
    void* ptr1 = buffer1.ptr;

    // Release it
    allocator.releaseBuffer(*mComm, ptr1);

    // Request another buffer of the same size - should reuse
    auto buffer2 = allocator.requestBuffer(*mComm, bufferSize);
    EXPECT_TRUE(buffer2.isValid());
    EXPECT_EQ(buffer2.ptr, ptr1); // Should be the same buffer

    allocator.releaseBuffer(*mComm, buffer2.ptr);
}

TEST_F(NCCLWindowAllocatorTest, BestFitReuse)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    // Allocate buffers of different sizes
    auto buffer1MB = allocator.requestBuffer(*mComm, 1024 * 1024);
    auto buffer2MB = allocator.requestBuffer(*mComm, 2 * 1024 * 1024);
    auto buffer512KB = allocator.requestBuffer(*mComm, 512 * 1024);

    void* ptr1MB = buffer1MB.ptr;
    void* ptr2MB = buffer2MB.ptr;
    void* ptr512KB = buffer512KB.ptr;

    // Release all
    allocator.releaseBuffer(*mComm, ptr1MB);
    allocator.releaseBuffer(*mComm, ptr2MB);
    allocator.releaseBuffer(*mComm, ptr512KB);

    // Request 768KB - should reuse 1MB (best fit, smallest that fits)
    auto buffer768KB = allocator.requestBuffer(*mComm, 768 * 1024);
    EXPECT_TRUE(buffer768KB.isValid());
    EXPECT_EQ(buffer768KB.ptr, ptr1MB);       // Should reuse 1MB buffer
    EXPECT_EQ(buffer768KB.size, 1024 * 1024); // Original size

    allocator.releaseBuffer(*mComm, buffer768KB.ptr);
}

TEST_F(NCCLWindowAllocatorTest, MultipleBuffers)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 256 * 1024;
    std::vector<void*> ptrs;

    // Allocate multiple buffers
    for (int i = 0; i < 5; ++i)
    {
        auto buffer = allocator.requestBuffer(*mComm, bufferSize);
        EXPECT_TRUE(buffer.isValid());
        ptrs.push_back(buffer.ptr);
    }

    EXPECT_EQ(allocator.getBufferCount(*mComm), 5);
    EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 5);

    // Release all
    for (auto* ptr : ptrs)
    {
        allocator.releaseBuffer(*mComm, ptr);
    }

    EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0);
    EXPECT_EQ(allocator.getBufferCount(*mComm), 5); // Buffers still exist, just not in use
}

TEST_F(NCCLWindowAllocatorTest, SearchBuffer)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 128 * 1024;
    auto buffer = allocator.requestBuffer(*mComm, bufferSize);

    // Test searchBuffer
    auto found = allocator.searchBuffer(*mComm, buffer.ptr);
    EXPECT_TRUE(found.isValid());
    EXPECT_EQ(found.ptr, buffer.ptr);
    // Compare against actual allocated size (ncclMemAlloc may allocate more than requested)
    EXPECT_EQ(found.size, buffer.size);
    EXPECT_GE(found.size, bufferSize); // At least the requested size

    // Test search for non-existent buffer
    void* fakePtr = reinterpret_cast<void*>(0xDEADBEEF);
    auto notFound = allocator.searchBuffer(*mComm, fakePtr);
    EXPECT_FALSE(notFound.isValid());

    allocator.releaseBuffer(*mComm, buffer.ptr);
}

TEST_F(NCCLWindowAllocatorTest, GetWindowAndSize)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 64 * 1024;
    auto buffer = allocator.requestBuffer(*mComm, bufferSize);

    // Test getWindow
    auto window = allocator.getWindow(*mComm, buffer.ptr);
    EXPECT_NE(window, nullptr);
    EXPECT_EQ(window, buffer.window);

    // Test getSize - compare against actual allocated size (ncclMemAlloc may allocate more than requested)
    auto size = allocator.getSize(*mComm, buffer.ptr);
    EXPECT_EQ(size, buffer.size);
    EXPECT_GE(size, bufferSize); // At least the requested size

    // Test with invalid pointer
    void* fakePtr = reinterpret_cast<void*>(0xDEADBEEF);
    EXPECT_EQ(allocator.getWindow(*mComm, fakePtr), nullptr);
    EXPECT_EQ(allocator.getSize(*mComm, fakePtr), 0);

    allocator.releaseBuffer(*mComm, buffer.ptr);
}

TEST_F(NCCLWindowAllocatorTest, GetBufferInfo)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    const size_t bufferSize = 32 * 1024;
    auto buffer = allocator.requestBuffer(*mComm, bufferSize);

    auto info = allocator.getBufferInfo(*mComm, buffer.ptr);
    EXPECT_TRUE(info.isValid());
    EXPECT_EQ(info.ptr, buffer.ptr);
    EXPECT_EQ(info.size, buffer.size);
    EXPECT_EQ(info.handle, buffer.handle);
    EXPECT_EQ(info.window, buffer.window);

    allocator.releaseBuffer(*mComm, buffer.ptr);
}

TEST_F(NCCLWindowAllocatorTest, ScopedBuffer)
{
    const size_t bufferSize = 16 * 1024;

    {
        nccl_util::ScopedNCCLWindowBuffer scopedBuffer(*mComm, bufferSize);
        EXPECT_TRUE(scopedBuffer.getBuffer().isValid());
        EXPECT_NE(scopedBuffer.getPtr(), nullptr);
        // Compare against actual allocated size (ncclMemAlloc may allocate more than requested)
        EXPECT_EQ(scopedBuffer.getSize(), scopedBuffer.getBuffer().size);
        EXPECT_GE(scopedBuffer.getSize(), bufferSize); // At least the requested size
        EXPECT_NE(scopedBuffer.getWindow(), nullptr);

        // Buffer should be in use
        auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();
        EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1);
    }

    // Buffer should be released when scoped buffer goes out of scope
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();
    EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0);
}

TEST_F(NCCLWindowAllocatorTest, CleanupOnCommDestroy)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    // Create a separate comm using split for this test
    auto testComm = createSplitComm(*mComm, 0, mRank);

    // Store the raw comm value before destruction
    ncclComm_t rawComm = *testComm;

    // Allocate some buffers
    const size_t bufferSize = 8 * 1024;
    auto buffer1 = allocator.requestBuffer(*testComm, bufferSize);
    auto buffer2 = allocator.requestBuffer(*testComm, bufferSize * 2);

    EXPECT_EQ(allocator.getBufferCount(*testComm), 2);
    EXPECT_EQ(allocator.getBufferInUseCount(*testComm), 2);

    // Verify buffers are valid
    EXPECT_TRUE(buffer1.isValid());
    EXPECT_TRUE(buffer2.isValid());

    // Manually release buffers before cleanup to avoid warnings
    allocator.releaseBuffer(*testComm, buffer1.ptr);
    allocator.releaseBuffer(*testComm, buffer2.ptr);

    // Verify buffers are released but still exist in pool
    EXPECT_EQ(allocator.getBufferInUseCount(*testComm), 0);
    EXPECT_EQ(allocator.getBufferCount(*testComm), 2); // Buffers still exist, just not in use

    // Destroy the communicator - buffers should be cleaned up automatically
    testComm.reset();

    // Verify cleanup: check that the old comm (now destroyed) no longer has buffers
    // Note: The comm is destroyed, but we can still check the allocator's internal state
    // The cleanup should have removed all buffers for this comm
    EXPECT_EQ(allocator.getBufferCount(rawComm), 0);
    EXPECT_EQ(allocator.getBufferInUseCount(rawComm), 0);

    // Note: isCommValid only checks for null, not cleaned-up state, because NCCL can reuse addresses
    // The real check is that buffers are gone, which we verify above
}

TEST_F(NCCLWindowAllocatorTest, CommValidity)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    // Valid comm should be valid
    EXPECT_TRUE(allocator.isCommValid(*mComm));

    // Null comm should be invalid
    EXPECT_FALSE(allocator.isCommValid(nullptr));
}

//==============================================================================
// Integration Tests
//==============================================================================

TEST_F(NCCLWindowAllocatorTest, MultipleComms)
{
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();

    // Create two different communicators using split (different colors)
    auto comm1 = createSplitComm(*mComm, 0, mRank);
    auto comm2 = createSplitComm(*mComm, 1, mRank);

    const size_t bufferSize = 4 * 1024;

    // Allocate buffers from both comms
    auto buffer1 = allocator.requestBuffer(*comm1, bufferSize);
    auto buffer2 = allocator.requestBuffer(*comm2, bufferSize);

    EXPECT_TRUE(buffer1.isValid());
    EXPECT_TRUE(buffer2.isValid());

    // Buffers should be tracked separately per comm
    EXPECT_EQ(allocator.getBufferCount(*comm1), 1);
    EXPECT_EQ(allocator.getBufferCount(*comm2), 1);
    EXPECT_NE(buffer1.ptr, buffer2.ptr); // Different buffers from different comms

    allocator.releaseBuffer(*comm1, buffer1.ptr);
    allocator.releaseBuffer(*comm2, buffer2.ptr);

    // Clean up comms
    comm1.reset();
    comm2.reset();
}

#if ENABLE_MULTI_DEVICE && BUILD_PYT

//==============================================================================
// createNCCLWindowTensor Tests
//==============================================================================

// Fixture for the torch-backed tensor helper; performs the same rank/device/NCCL checks
// as NCCLWindowAllocatorTest before obtaining the all-ranks communicator.
class CreateNCCLWindowTensorTest : public ::testing::Test
{
protected:
    void SetUp() override
    {
        auto& comm = mpi::MpiComm::world();
        mWorldSize = comm.getSize();
        mRank = comm.getRank();

        if (mWorldSize < 2)
        {
            GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")";
        }

        // Set CUDA device for this rank (required before NCCL initialization)
        int deviceCount = 0;
        TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
        if (deviceCount > 0)
        {
            int deviceId = mRank % deviceCount;
            TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
        }

        // Check if NCCL symmetric is supported
        auto& ncclHelper = nccl_util::NCCLHelper::getInstance();
        if (!ncclHelper.isLoaded())
        {
            GTEST_SKIP() << "NCCL library with symmetric memory support is not available";
        }

        std::set<int> group;
        for (int i = 0; i < mWorldSize; ++i)
        {
            group.insert(i);
        }
        mComm = getComm(group);
    }

    void TearDown() override
    {
        mComm.reset();
    }

    // Initialized so the members hold defined values even if SetUp() exits early.
    int mWorldSize{0};
    int mRank{0};
    std::shared_ptr<ncclComm_t> mComm;
};

TEST_F(CreateNCCLWindowTensorTest, BasicTensorCreation)
{
    using nccl_util::createNCCLWindowTensor;

    // Create a tensor with shape [4, 8] and float32 dtype
    std::vector<int64_t> shape = {4, 8};
    auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32);

    // Verify tensor properties
    EXPECT_TRUE(tensor.defined());
    EXPECT_EQ(tensor.dtype(), torch::kFloat32);
    EXPECT_EQ(tensor.device().type(), torch::kCUDA);
    EXPECT_EQ(tensor.dim(), 2);
    EXPECT_EQ(tensor.size(0), 4);
    EXPECT_EQ(tensor.size(1), 8);
    EXPECT_EQ(tensor.numel(), 4 * 8);

    // Verify buffer properties
    EXPECT_TRUE(buffer.isValid());
    EXPECT_NE(buffer.ptr, nullptr);
    // ncclMemAlloc may allocate more than requested, so check at least the requested size
    EXPECT_GE(buffer.size, 4 * 8 * sizeof(float));
    EXPECT_NE(buffer.window, nullptr);

    // Verify tensor data pointer matches buffer pointer
    EXPECT_EQ(tensor.data_ptr(), buffer.ptr);

    // Tensor should be in use
    auto& allocator = nccl_util::NCCLWindowAllocator::getInstance();
    EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1);
}

TEST_F(CreateNCCLWindowTensorTest, DifferentDtypes)
{
    using nccl_util::createNCCLWindowTensor;

    std::vector<int64_t> shape = {10};

    // Test float32
    {
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32);
        EXPECT_EQ(tensor.dtype(), torch::kFloat32);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 10 * sizeof(float));
        EXPECT_EQ(tensor.data_ptr(), buffer.ptr);
    }

    // Test float16
    {
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat16);
        EXPECT_EQ(tensor.dtype(), torch::kFloat16);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 10 * sizeof(at::Half));
        EXPECT_EQ(tensor.data_ptr(), buffer.ptr);
    }

    // Test int32
    {
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kInt32);
        EXPECT_EQ(tensor.dtype(), torch::kInt32);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 10 * sizeof(int32_t));
        EXPECT_EQ(tensor.data_ptr(), buffer.ptr);
    }
}

TEST_F(CreateNCCLWindowTensorTest, DifferentShapes)
{
    using nccl_util::createNCCLWindowTensor;

    // 1D tensor
    {
        std::vector<int64_t> shape = {100};
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32);
        EXPECT_EQ(tensor.dim(), 1);
        EXPECT_EQ(tensor.size(0), 100);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 100 * sizeof(float));
    }

    // 3D tensor
    {
        std::vector<int64_t> shape = {2, 3, 4};
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32);
        EXPECT_EQ(tensor.dim(), 3);
        EXPECT_EQ(tensor.size(0), 2);
        EXPECT_EQ(tensor.size(1), 3);
        EXPECT_EQ(tensor.size(2), 4);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 2 * 3 * 4 * sizeof(float));
    }

    // 4D tensor
    {
        std::vector<int64_t> shape = {1, 2, 3, 4};
        auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32);
        EXPECT_EQ(tensor.dim(), 4);
        EXPECT_EQ(tensor.numel(), 1 * 2 * 3 * 4);
        // ncclMemAlloc may allocate more than requested, so check at least the requested size
        EXPECT_GE(buffer.size, 1 * 2 * 3 * 4 * sizeof(float));
    }
}

TEST_F(CreateNCCLWindowTensorTest, TensorDeleterReleasesBuffer) { using nccl_util::createNCCLWindowTensor; auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); { std::vector shape = {16, 16}; auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1); EXPECT_TRUE(buffer.isValid()); void* bufferPtr = buffer.ptr; // Tensor goes out of scope - deleter should release the buffer } // Buffer should be released (not in use anymore) EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0); // Buffer should still exist in the pool (for reuse) EXPECT_GE(allocator.getBufferCount(*mComm), 1); } TEST_F(CreateNCCLWindowTensorTest, MultipleTensors) { using nccl_util::createNCCLWindowTensor; auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); std::vector shape = {8, 8}; auto [tensor1, buffer1] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); auto [tensor2, buffer2] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); auto [tensor3, buffer3] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 3); EXPECT_NE(buffer1.ptr, buffer2.ptr); EXPECT_NE(buffer2.ptr, buffer3.ptr); EXPECT_NE(buffer1.ptr, buffer3.ptr); // All tensors should be valid EXPECT_TRUE(tensor1.defined()); EXPECT_TRUE(tensor2.defined()); EXPECT_TRUE(tensor3.defined()); } TEST_F(CreateNCCLWindowTensorTest, TensorStrides) { using nccl_util::createNCCLWindowTensor; std::vector shape = {3, 4, 5}; auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); // Verify strides are correct (row-major order) EXPECT_EQ(tensor.stride(0), 4 * 5); // stride for first dimension EXPECT_EQ(tensor.stride(1), 5); // stride for second dimension EXPECT_EQ(tensor.stride(2), 1); // stride for third dimension } #endif // ENABLE_MULTI_DEVICE && BUILD_PYT #endif // ENABLE_MULTI_DEVICE