mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 14:07:21 +08:00
[TRTLLM-4406][feat] LLM sleep & wakeup Part 1: virtual device memory (#5034)
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
This commit is contained in:
parent
b9fe0fa7ec
commit
a2f271c8e0
@ -1,5 +1,6 @@
|
||||
Checks: '*,
|
||||
-altera-id-dependent-backward-branch,
|
||||
-altera-struct-pack-align,
|
||||
-altera-unroll-loops,
|
||||
-boost-use-ranges,
|
||||
-cppcoreguidelines-avoid-do-while,
|
||||
@ -9,8 +10,10 @@ Checks: '*,
|
||||
-fuchsia-default-arguments-calls,
|
||||
-fuchsia-default-arguments-declarations,
|
||||
-fuchsia-overloaded-operator,
|
||||
-fuchsia-virtual-inheritance,
|
||||
-hicpp-vararg,
|
||||
-llvm-else-after-return,
|
||||
-llvmlibc-*,
|
||||
-misc-include-cleaner,
|
||||
-misc-non-private-member-variables-in-classes,
|
||||
-modernize-use-trailing-return-type'
|
||||
|
||||
540
cpp/include/tensorrt_llm/runtime/virtualMemory.h
Normal file
@ -0,0 +1,540 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/memoryCounters.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <unistd.h>
|
||||
#include <utility>
|
||||
|
||||
class VirtualMemoryManagerTest;
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
/**
|
||||
* CUDAVirtualMemoryChunk is a handle to a piece of CUDA memory allocation,
|
||||
* providing the ability to release and rematerialize the allocation.
|
||||
*/
|
||||
class CUDAVirtualMemoryChunk
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* CUDAVirtualMemoryChunk::Creator is the interface to obtain a CUmemGenericAllocationHandle,
|
||||
* either by creating one locally, or importing one from remote.
|
||||
*/
|
||||
struct Creator
|
||||
{
|
||||
Creator() = default;
|
||||
virtual ~Creator() = default;
|
||||
Creator(Creator const&) = default;
|
||||
Creator& operator=(Creator const&) = default;
|
||||
Creator(Creator&&) = default;
|
||||
Creator& operator=(Creator&&) = default;
|
||||
|
||||
// Note: create() shall not leak resources when throwing exceptions.
|
||||
// release() will be called if, and only if, create() succeeds.
|
||||
// release() will be called with destructing=true when the CUDAVirtualMemoryChunk
|
||||
// is being destructed.
|
||||
virtual CUmemGenericAllocationHandle create() = 0;
|
||||
virtual void release(CUmemGenericAllocationHandle handle, bool destructing) = 0;
|
||||
};
|
||||
|
||||
using CreatorPtr = std::unique_ptr<Creator>;
|
||||
|
||||
/**
|
||||
* CUDAVirtualMemoryChunk::Configurator is the interface to configure a CUmemGenericAllocationHandle:
|
||||
* - Map into virtual address
|
||||
* - Bind to multicast object
|
||||
* - Backup and restore memory content
|
||||
*/
|
||||
struct Configurator
|
||||
{
|
||||
Configurator() = default;
|
||||
virtual ~Configurator() = default;
|
||||
Configurator(Configurator const&) = default;
|
||||
Configurator& operator=(Configurator const&) = default;
|
||||
Configurator(Configurator&&) = default;
|
||||
Configurator& operator=(Configurator&&) = default;
|
||||
|
||||
// Note: setup() shall not leak resources when throwing exceptions.
|
||||
// teardown() will be called if, and only if, setup() succeeds.
|
||||
// teardown() will be called with destructing=true when the CUDAVirtualMemoryChunk
|
||||
// is being destructed.
|
||||
virtual void setup(CUmemGenericAllocationHandle handle) = 0;
|
||||
virtual void teardown(CUmemGenericAllocationHandle handle, bool destructing) = 0;
|
||||
};
|
||||
|
||||
using ConfiguratorPtr = std::unique_ptr<Configurator>;
|
||||
using Configurators = std::vector<ConfiguratorPtr>;
|
||||
|
||||
enum Status
|
||||
{
|
||||
INVALID, // This is a default constructed invalid CUDAVirtualMemoryChunk.
|
||||
RELEASED, // The memory represented by this CUDAVirtualMemoryChunk is not allocated.
|
||||
MATERIALIZED, // The memory represented by this CUDAVirtualMemoryChunk is allocated.
|
||||
ERRORED, // Error happened during materialize() or release().
|
||||
// This CUDAVirtualMemoryChunk cannot be used anymore.
|
||||
};
|
||||
|
||||
[[nodiscard]] Status status() const noexcept
|
||||
{
|
||||
if (mCreator == nullptr)
|
||||
{
|
||||
return INVALID;
|
||||
}
|
||||
|
||||
if (mState == 0 && mHandle == 0)
|
||||
{
|
||||
return RELEASED;
|
||||
}
|
||||
|
||||
if (mState == mConfigurators.size() && mHandle != 0)
|
||||
{
|
||||
return MATERIALIZED;
|
||||
}
|
||||
|
||||
return ERRORED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Materialize this CUDAVirtualMemoryChunk.
|
||||
* Shall be called only when status() == RELEASED.
|
||||
*
|
||||
* Calls creator.create(), and then configurator.setup() for each configurator in order.
|
||||
*
|
||||
* Stops at the first thrown exception and propagates it.
|
||||
*/
|
||||
void materialize();
|
||||
|
||||
/**
|
||||
* Release this CUDAVirtualMemoryChunk.
|
||||
* Shall be called only when status() == MATERIALIZED, or materialize() throws.
|
||||
* Will be called automatically by destructor if necessary.
|
||||
*
|
||||
* Calls configurator.teardown(), in reverse order, for each configurator whose setup() succeeded in materialize(),
|
||||
* and then creator.release().
|
||||
*
|
||||
* Never stops early upon exception. The last thrown exception will be propagated, and others logged.
|
||||
*/
|
||||
void release()
|
||||
{
|
||||
_release(false);
|
||||
}
|
||||
|
||||
CUDAVirtualMemoryChunk(CUDAVirtualMemoryChunk const&) = delete;
|
||||
CUDAVirtualMemoryChunk& operator=(CUDAVirtualMemoryChunk const&) = delete;
|
||||
|
||||
CUDAVirtualMemoryChunk(CUDAVirtualMemoryChunk&& other) noexcept
|
||||
{
|
||||
mCreator = std::move(other.mCreator);
|
||||
mConfigurators = std::move(other.mConfigurators);
|
||||
mHandle = other.mHandle;
|
||||
mState = other.mState;
|
||||
new (&other) CUDAVirtualMemoryChunk; // Put other into default constructed state
|
||||
}
|
||||
|
||||
CUDAVirtualMemoryChunk& operator=(CUDAVirtualMemoryChunk&& other)
|
||||
{
|
||||
this->~CUDAVirtualMemoryChunk(); // May throw if the current virtual memory needs release
|
||||
new (this) CUDAVirtualMemoryChunk(std::move(other));
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUDAVirtualMemoryChunk() noexcept = default;
|
||||
|
||||
CUDAVirtualMemoryChunk(CreatorPtr&& creator, Configurators&& configurators)
|
||||
: mCreator(std::move(creator))
|
||||
, mConfigurators(std::move(configurators))
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~CUDAVirtualMemoryChunk()
|
||||
{
|
||||
// Calling release() is necessary if materialize() succeeded or threw an exception.
|
||||
// If release() was already called by the user, whether it succeeded or threw an exception,
|
||||
// we shouldn't call release() again.
|
||||
if (mHandle != 0 && mState != INVALID_STATE)
|
||||
{
|
||||
_release(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test if this CUDAVirtualMemoryChunk is managing a memory block.
|
||||
*/
|
||||
explicit operator bool() const noexcept
|
||||
{
|
||||
return mCreator != nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
void _release(bool destructing);
|
||||
|
||||
constexpr static size_t INVALID_STATE = static_cast<size_t>(-1);
|
||||
size_t mState = 0;
|
||||
CUmemGenericAllocationHandle mHandle{};
|
||||
std::unique_ptr<Creator> mCreator;
|
||||
std::vector<std::unique_ptr<Configurator>> mConfigurators;
|
||||
};
|
||||
|
||||
/**
|
||||
* LocalCreator creates memory allocation locally through cuMemCreate.
|
||||
*/
|
||||
template <bool count = true>
|
||||
struct LocalCreator : CUDAVirtualMemoryChunk::Creator
|
||||
{
|
||||
LocalCreator(CUmemAllocationProp const& prop, size_t size)
|
||||
: mProp(prop)
|
||||
, mSize(size)
|
||||
{
|
||||
}
|
||||
|
||||
CUmemGenericAllocationHandle create() override
|
||||
{
|
||||
CUmemGenericAllocationHandle handle{};
|
||||
TLLM_CU_CHECK(cuMemCreate(&handle, mSize, &mProp, 0));
|
||||
if constexpr (count)
|
||||
{
|
||||
MemoryCounters::getInstance().allocate(
|
||||
mProp.location.type == CU_MEM_LOCATION_TYPE_DEVICE ? MemoryType::kGPU : MemoryType::kPINNED, mSize);
|
||||
}
|
||||
return handle;
|
||||
}
|
||||
|
||||
void release(CUmemGenericAllocationHandle handle, bool destructing) override
|
||||
{
|
||||
TLLM_CU_CHECK_FREE_RESOURCE(cuMemRelease(handle));
|
||||
if constexpr (count)
|
||||
{
|
||||
MemoryCounters::getInstance().deallocate(
|
||||
mProp.location.type == CU_MEM_LOCATION_TYPE_DEVICE ? MemoryType::kGPU : MemoryType::kPINNED, mSize);
|
||||
}
|
||||
}
|
||||
|
||||
CUmemAllocationProp mProp{};
|
||||
size_t mSize{};
|
||||
};
|
||||
|
||||
/**
|
||||
* UnicastConfigurator maps the allocation handle into the specified unicast address range.
|
||||
*/
|
||||
struct UnicastConfigurator : CUDAVirtualMemoryChunk::Configurator
|
||||
{
|
||||
UnicastConfigurator(CUdeviceptr address, size_t size, CUmemAccessDesc const& desc)
|
||||
: mAddress(address)
|
||||
, mSize(size)
|
||||
, mDesc(desc)
|
||||
{
|
||||
}
|
||||
|
||||
void setup(CUmemGenericAllocationHandle handle) override
|
||||
{
|
||||
TLLM_CU_CHECK(cuMemMap(mAddress, mSize, 0, handle, 0));
|
||||
TLLM_CU_CHECK(cuMemSetAccess(mAddress, mSize, &mDesc, 1));
|
||||
}
|
||||
|
||||
void teardown(CUmemGenericAllocationHandle, bool) override
|
||||
{
|
||||
TLLM_CU_CHECK_FREE_RESOURCE(cuMemUnmap(mAddress, mSize));
|
||||
}
|
||||
|
||||
CUdeviceptr mAddress;
|
||||
size_t mSize;
|
||||
CUmemAccessDesc mDesc;
|
||||
};
|
||||
|
||||
/**
|
||||
* MulticastConfigurator binds the allocation handle to the given multicast object and offset.
|
||||
*/
|
||||
struct MulticastConfigurator : CUDAVirtualMemoryChunk::Configurator
|
||||
{
|
||||
void setup(CUmemGenericAllocationHandle handle) override
|
||||
{
|
||||
TLLM_CU_CHECK(cuMulticastBindMem(mMulticast, 0, handle, mBindOffset, mSize, 0));
|
||||
}
|
||||
|
||||
void teardown(CUmemGenericAllocationHandle, bool) override
|
||||
{
|
||||
TLLM_CU_CHECK_FREE_RESOURCE(cuMulticastUnbind(mMulticast, mDevice, 0, mSize));
|
||||
}
|
||||
|
||||
CUmemGenericAllocationHandle mMulticast;
|
||||
size_t mBindOffset;
|
||||
CUdevice mDevice;
|
||||
size_t mSize;
|
||||
};
|
||||
|
||||
/**
|
||||
* MemsetConfigurator fills the memory with given value.
|
||||
*/
|
||||
struct MemsetConfigurator : CUDAVirtualMemoryChunk::Configurator
|
||||
{
|
||||
MemsetConfigurator(CUdeviceptr address, size_t size, uint8_t value, CUstream stream)
|
||||
: mAddress(address)
|
||||
, mSize(size)
|
||||
, mStream(stream)
|
||||
, mValue(value)
|
||||
{
|
||||
}
|
||||
|
||||
void setup(CUmemGenericAllocationHandle) override
|
||||
{
|
||||
if (mFirstTime)
|
||||
{
|
||||
mFirstTime = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CU_CHECK(cuMemsetD8Async(mAddress, mValue, mSize, mStream));
|
||||
}
|
||||
}
|
||||
|
||||
void teardown(CUmemGenericAllocationHandle, bool) noexcept override {}
|
||||
|
||||
CUdeviceptr mAddress;
|
||||
size_t mSize;
|
||||
CUstream mStream{};
|
||||
uint8_t mValue;
|
||||
bool mFirstTime = true;
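// Presumably the first setup() is a no-op because a freshly created allocation has no prior contents
// to restore; the memset then runs on every re-materialization, so RestoreMode::MEMSET hands back
// zeroed memory after wake-up.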
|
||||
};
|
||||
|
||||
/**
|
||||
* OffloadConfigurator offloads the content of the allocation to backup storage during teardown,
|
||||
* and restores the content on the following setup.
|
||||
*/
|
||||
struct OffloadConfigurator : CUDAVirtualMemoryChunk::Configurator
|
||||
{
|
||||
OffloadConfigurator(CUdeviceptr address, size_t size, MemoryType backType, CUstream stream, bool ondemand = false)
|
||||
: mAddress(address)
|
||||
, mSize(size)
|
||||
, mBackType(backType)
|
||||
, mStream(stream)
|
||||
, mOndemand(ondemand)
|
||||
{
|
||||
}
|
||||
|
||||
void setup(CUmemGenericAllocationHandle handle) override;
|
||||
void teardown(CUmemGenericAllocationHandle handle, bool destructing) override;
|
||||
|
||||
CUdeviceptr mAddress;
|
||||
size_t mSize;
|
||||
MemoryType mBackType;
|
||||
CUstream mStream;
|
||||
bool mOndemand;
|
||||
|
||||
IBuffer::UniquePtr mBackedStorage;
|
||||
};
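// Illustrative sketch (not part of this header): assembling a chunk by hand from the pieces above.
// `device` and `size` are placeholders; `size` is assumed to already be allocation-granularity aligned.
//
//   CUdeviceptr address{};
//   TLLM_CU_CHECK(cuMemAddressReserve(&address, size, 0, {}, 0));
//
//   CUmemAllocationProp prop{};
//   prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
//   prop.location = {CU_MEM_LOCATION_TYPE_DEVICE, device};
//   CUmemAccessDesc access{{CU_MEM_LOCATION_TYPE_DEVICE, device}, CU_MEM_ACCESS_FLAGS_PROT_READWRITE};
//
//   CUDAVirtualMemoryChunk::Configurators configurators;
//   configurators.push_back(std::make_unique<UnicastConfigurator>(address, size, access));
//
//   CUDAVirtualMemoryChunk chunk{std::make_unique<LocalCreator<>>(prop, size), std::move(configurators)};
//   chunk.materialize(); // physical memory is created and mapped at `address`
//   chunk.release();     // physical memory is freed; `address` stays reserved for re-materialization
//   TLLM_CU_CHECK(cuMemAddressFree(address, size));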
|
||||
|
||||
class CudaVirtualMemoryManager
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* Add memory to be managed by this manager.
|
||||
* @param handle Unique handle provided to reference this memory in `remove`.
|
||||
* @param tag Tag for the memory, so this memory can be targeted in `releaseWithTag` and `materializeWithTag`.
|
||||
* @param memory The CUDAVirtualMemoryChunk object.
|
||||
*
|
||||
* The memory and internal state will remain valid if any exception is thrown.
|
||||
*/
|
||||
void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk&& memory);
|
||||
|
||||
/**
|
||||
* Creates and adds memory to be managed by this manager. The created memory is automatically materialized.
|
||||
* @param handle Unique handle provided to reference this memory in `remove`.
|
||||
* @param tag Tag for the memory, so this memory can be targeted in `releaseWithTag` and
|
||||
* `materializeWithTag`.
|
||||
* @param creator The creator for the memory.
|
||||
* @param configurators The configurators for the memory.
|
||||
*
|
||||
* The internal state will remain valid if any exception is thrown.
|
||||
*/
|
||||
void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator,
|
||||
CUDAVirtualMemoryChunk::Configurators&& configurators);
|
||||
|
||||
template <typename... Configurators>
|
||||
void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator,
|
||||
Configurators&&... configurators)
|
||||
{
|
||||
// A braced initializer list cannot move unique_ptr elements into the vector, so build it explicitly.
CUDAVirtualMemoryChunk::Configurators configuratorVec;
configuratorVec.reserve(sizeof...(Configurators));
(configuratorVec.push_back(std::forward<Configurators>(configurators)), ...);
add(handle, std::move(tag), std::move(creator), std::move(configuratorVec));
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove the memory from the manager.
|
||||
* @param handle The handle provided to `add`.
|
||||
* @return The CUDAVirtualMemoryChunk object. If the handle is unknown, an empty CUDAVirtualMemoryChunk will be returned.
|
||||
*/
|
||||
CUDAVirtualMemoryChunk remove(uintptr_t handle) noexcept;
|
||||
|
||||
/**
|
||||
* Call release for CUDAVirtualMemoryChunk objects with a given tag.
|
||||
* @param tag the tag to select target memories.
|
||||
* @return Number of objects selected.
|
||||
*
|
||||
* This function will always call `CUDAVirtualMemoryChunk::release` on all selected objects.
|
||||
* The last exception thrown by `CUDAVirtualMemoryChunk::release` will be rethrown, and others will be logged.
|
||||
*
|
||||
* If any CUDAVirtualMemoryChunk threw an exception during `release`, it will be removed from the manager.
|
||||
* Call `retrieveBadHandles` to retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception.
|
||||
*/
|
||||
size_t releaseWithTag(std::string const& tag);
|
||||
|
||||
/**
|
||||
* Call materialize for CUDAVirtualMemoryChunk objects with a given tag.
|
||||
* @param tag the tag to select target memories.
|
||||
* @return Number of objects selected.
|
||||
*
|
||||
* This function will stop at the first `CUDAVirtualMemoryChunk::materialize` that throws an exception,
|
||||
* and attempt to roll back previous successful `materialize` by calling `release`.
|
||||
* The exception thrown by `CUDAVirtualMemoryChunk::materialize` will be rethrown,
|
||||
* and any exception thrown by `release` will be logged.
|
||||
*
|
||||
* If any CUDAVirtualMemoryChunk threw an exception during `materialize` or `release`, it will be removed from the
|
||||
* manager. Successfully rolled-back CUDAVirtualMemoryChunk objects will not be removed.
|
||||
* Call `retrieveBadHandles` to retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception.
|
||||
*/
|
||||
size_t materializeWithTag(std::string const& tag);
|
||||
|
||||
/**
|
||||
* Retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception and reset the list.
|
||||
* The returned list may not include all removed CUDAVirtualMemoryChunk handles if OOM happened.
|
||||
* This method is only for diagnostic purposes, and should not be called concurrently with other methods.
|
||||
* @return The handle list.
|
||||
*/
|
||||
std::vector<uintptr_t> retrieveBadHandles() noexcept;
|
||||
|
||||
private:
|
||||
CUDAVirtualMemoryChunk unsafeRemove(uintptr_t handle) noexcept;
|
||||
void addBadHandle(uintptr_t handle) noexcept;
|
||||
|
||||
struct Entry;
|
||||
// Unordered map invalidates iterators upon rehash, so we can only use the ordered map.
|
||||
using PointerMemoryMap = std::map<uintptr_t, Entry>;
|
||||
using TagEntryMap = std::multimap<std::string, PointerMemoryMap::iterator>;
|
||||
|
||||
struct Entry
|
||||
{
|
||||
CUDAVirtualMemoryChunk mMemory;
|
||||
TagEntryMap::iterator mEntryIt;
|
||||
};
|
||||
|
||||
std::mutex mMutex;
|
||||
PointerMemoryMap mMemories;
|
||||
TagEntryMap mEntries;
|
||||
std::vector<uintptr_t> mBadHandles;
|
||||
|
||||
friend VirtualMemoryManagerTest;
|
||||
};
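// Illustrative sketch (not part of this header): registering a chunk with the process-wide manager
// and cycling it by tag, e.g. around an LLM sleep/wakeup. The handle value and tag are examples only.
//
//   auto& manager = getVirtualMemoryManager();
//   manager.add(/*handle=*/0x1000, "weights", std::move(chunk)); // chunk built as sketched above
//   manager.releaseWithTag("weights");     // frees (and optionally offloads) the physical pages
//   manager.materializeWithTag("weights"); // re-allocates (and optionally restores) them
//   CUDAVirtualMemoryChunk removed = manager.remove(0x1000);     // hand the chunk back to the caller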
|
||||
|
||||
class CudaVirtualMemoryAllocator
|
||||
{
|
||||
using CudaStreamPtr = std::shared_ptr<CudaStream>;
|
||||
using Pointer = void*;
|
||||
|
||||
public:
|
||||
enum RestoreMode
|
||||
{
|
||||
NONE, // The memory is not backed. Upon rematerialize, memory has uninitialized content.
|
||||
MEMSET, // The memory is memset to zero upon rematerialize.
|
||||
CPU, // The memory is backed by normal CPU memory. The content is restored upon rematerialize.
|
||||
PINNED // The memory is backed by pinned CPU memory. The content is restored upon rematerialize.
|
||||
};
|
||||
|
||||
class Configuration
|
||||
{
|
||||
CudaVirtualMemoryManager& mManager;
|
||||
std::string mTag;
|
||||
CudaStreamPtr mBackStream;
|
||||
std::size_t mPageSize;
|
||||
RestoreMode mMode;
|
||||
bool mBackground{};
|
||||
|
||||
friend class CudaVirtualMemoryAllocator;
|
||||
friend void setVirtualMemoryAllocator(
|
||||
std::string const& tag, RestoreMode mode, std::shared_ptr<CudaStream> backStream);
|
||||
|
||||
public:
|
||||
/**
|
||||
* CudaVirtualMemoryAllocator::Configuration
|
||||
* @param manager Manager used to track and manage virtual memories
|
||||
* @param tag The tag for allocated memories
|
||||
* @param mode Backed storage mode
|
||||
* @param backStream The CUDA stream used for restoring memory content
|
||||
* Note: Virtual Address Allocation is not async. The stream is not used in allocation.
|
||||
*/
|
||||
Configuration(CudaVirtualMemoryManager& manager, std::string tag, RestoreMode mode, CudaStreamPtr backStream)
|
||||
: mManager(manager)
|
||||
, mTag(std::move(tag))
|
||||
, mBackStream(std::move(backStream))
|
||||
, mPageSize(getpagesize())
|
||||
, mMode(mode)
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] std::size_t pageAligned(std::size_t n) const noexcept
|
||||
{
|
||||
return (n + mPageSize - 1) & ~(mPageSize - 1);
|
||||
}
|
||||
|
||||
// Background configuration, used to indicate no virtual memory allocator is explicitly configured by the user.
|
||||
static Configuration backgroundConfiguration;
|
||||
|
||||
private:
|
||||
Configuration(CudaVirtualMemoryManager& manager, std::string tag, RestoreMode mode, CudaStreamPtr backStream,
|
||||
bool background)
|
||||
: Configuration(manager, std::move(tag), mode, std::move(backStream))
|
||||
{
|
||||
mBackground = background;
|
||||
}
|
||||
};
|
||||
|
||||
explicit CudaVirtualMemoryAllocator(std::shared_ptr<Configuration> config)
|
||||
: mConfig(std::move(config))
|
||||
{
|
||||
}
|
||||
|
||||
// Tells whether a virtual memory allocator was explicitly configured; false for the background allocator.
|
||||
explicit operator bool() const noexcept
|
||||
{
|
||||
return !mConfig->mBackground;
|
||||
}
|
||||
|
||||
void allocate(Pointer* ptr, std::size_t n, int device) const;
|
||||
void deallocate(Pointer ptr, std::size_t n) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Configuration> mConfig;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
CudaVirtualMemoryManager& getVirtualMemoryManager();
|
||||
CudaVirtualMemoryAllocator getVirtualMemoryAllocator();
|
||||
void setVirtualMemoryAllocator(
|
||||
std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr<CudaStream> backStream);
|
||||
void clearVirtualMemoryAllocator();
|
||||
|
||||
} // namespace tensorrt_llm::runtime
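// Illustrative sketch (not part of this header) of the intended end-to-end flow, assuming
// BufferManager::gpuSync is the static helper extended in this change; the tag is an example only.
//
//   std::shared_ptr<CudaStream> stream = /* stream used for offload/restore copies */;
//   setVirtualMemoryAllocator("kv_cache", CudaVirtualMemoryAllocator::RestoreMode::PINNED, stream);
//   auto buffer = BufferManager::gpuSync(1 << 20, nvinfer1::DataType::kINT8); // allocated virtually, tagged "kv_cache"
//   clearVirtualMemoryAllocator();
//
//   getVirtualMemoryManager().releaseWithTag("kv_cache");     // free GPU pages, offload contents to pinned host memory
//   getVirtualMemoryManager().materializeWithTag("kv_cache"); // re-allocate the pages and restore their contents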
|
||||
@ -155,6 +155,16 @@ void checkDriver(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void checkDriverExitSafe(T result, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED)
|
||||
{
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %d.", func, result).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
/*
|
||||
@ -167,4 +177,11 @@ void checkDriver(
|
||||
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
// Avoid using CUDADriverWrapper when freeing resource, during which the global instance may already be freed.
|
||||
#define TLLM_CU_CHECK_FREE_RESOURCE(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::checkDriverExitSafe((stat), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
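// Example (illustrative): cleanup paths in this change use the macro so that a driver that is already
// shutting down does not turn teardown into a hard failure, e.g.
//   TLLM_CU_CHECK_FREE_RESOURCE(cuMemUnmap(reservedAddress, reservedSize));
// where reservedAddress/reservedSize describe whatever range was previously mapped.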
|
||||
|
||||
#endif // CUDA_DRIVER_WRAPPER_H
|
||||
|
||||
@ -39,6 +39,7 @@
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
|
||||
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include "tensorrt_llm/runtime/virtualMemory.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
@ -116,6 +117,10 @@ void initBindings(nb::module_& m)
|
||||
.def_rw("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer)
|
||||
.def(nb::self == nb::self);
|
||||
|
||||
nb::class_<tr::CudaVirtualMemoryManager>(m, "CudaVirtualMemoryManager")
|
||||
.def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, nb::arg("tag"))
|
||||
.def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, nb::arg("tag"));
|
||||
|
||||
nb::class_<tr::BufferManager>(m, "BufferManager")
|
||||
.def(nb::init<tr::BufferManager::CudaStreamPtr, bool>(), nb::arg("stream"), nb::arg("trim_pool") = false)
|
||||
.def_prop_ro("stream", &tr::BufferManager::getStream);
|
||||
@ -311,6 +316,29 @@ void initBindings(nb::module_& m)
|
||||
[](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
|
||||
"Calculate the maximum workspace size needed for low precision all-reduce operations");
|
||||
|
||||
nb::enum_<tr::CudaVirtualMemoryAllocator::RestoreMode>(m, "CudaVirtualMemoryAllocatorRestoreMode")
|
||||
.value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE)
|
||||
.value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU)
|
||||
.value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED)
|
||||
.value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET);
|
||||
|
||||
m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager",
|
||||
nb::rv_policy::reference);
|
||||
|
||||
m.def(
|
||||
"set_virtual_memory_allocator",
|
||||
[](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream)
|
||||
{
|
||||
static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t));
|
||||
tr::setVirtualMemoryAllocator(tag, mode,
|
||||
std::make_shared<tr::CudaStream>(
|
||||
reinterpret_cast<cudaStream_t>(stream), tensorrt_llm::common::getDevice(), false));
|
||||
},
|
||||
"Set the virtual memory allocator and start allocating virtual memory for CUDA allocations");
|
||||
|
||||
m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator,
|
||||
"Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
|
||||
|
||||
nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
|
||||
.def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>())
|
||||
.def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
|
||||
|
||||
@ -38,6 +38,7 @@
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
|
||||
#include "tensorrt_llm/runtime/tllmRuntime.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include "tensorrt_llm/runtime/virtualMemory.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
@ -213,6 +214,10 @@ void initBindings(pybind11::module_& m)
|
||||
.def_readwrite("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer)
|
||||
.def(py::self == py::self);
|
||||
|
||||
py::class_<tr::CudaVirtualMemoryManager>(m, "CudaVirtualMemoryManager")
|
||||
.def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, py::arg("tag"))
|
||||
.def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, py::arg("tag"));
|
||||
|
||||
py::classh<tr::BufferManager>(m, "BufferManager")
|
||||
.def(py::init<tr::BufferManager::CudaStreamPtr, bool>(), py::arg("stream"), py::arg("trim_pool") = false)
|
||||
.def_property_readonly("stream", &tr::BufferManager::getStream);
|
||||
@ -405,6 +410,29 @@ void initBindings(pybind11::module_& m)
|
||||
[](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
|
||||
"Calculate the maximum workspace size needed for low precision all-reduce operations");
|
||||
|
||||
py::enum_<tr::CudaVirtualMemoryAllocator::RestoreMode>(m, "CudaVirtualMemoryAllocatorRestoreMode")
|
||||
.value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE)
|
||||
.value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU)
|
||||
.value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED)
|
||||
.value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET);
|
||||
|
||||
m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager",
|
||||
py::return_value_policy::reference);
|
||||
|
||||
m.def(
|
||||
"set_virtual_memory_allocator",
|
||||
[](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream)
|
||||
{
|
||||
static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t));
|
||||
tr::setVirtualMemoryAllocator(tag, mode,
|
||||
std::make_shared<tr::CudaStream>(
|
||||
reinterpret_cast<cudaStream_t>(stream), tensorrt_llm::common::getDevice(), false));
|
||||
},
|
||||
"Set the virtual memory allocator and start allocating virtual memory for CUDA allocations");
|
||||
|
||||
m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator,
|
||||
"Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
|
||||
|
||||
py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
|
||||
.def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>())
|
||||
.def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
|
||||
|
||||
@ -55,7 +55,8 @@ set(SRCS
|
||||
tllmStreamReaders.cpp
|
||||
tllmLogger.cpp
|
||||
workerPool.cpp
|
||||
worldConfig.cpp)
|
||||
worldConfig.cpp
|
||||
virtualMemory.cpp)
|
||||
|
||||
include_directories(${API_INCLUDE_DIR}/tensorrt_llm/runtime)
|
||||
|
||||
|
||||
@ -39,6 +39,10 @@ BufferManager::BufferManager(CudaStreamPtr stream, bool trimPool)
|
||||
|
||||
BufferManager::IBufferPtr BufferManager::gpu(std::size_t size, nvinfer1::DataType type) const
|
||||
{
|
||||
if (auto vmAllocator = getVirtualMemoryAllocator())
|
||||
{
|
||||
return std::make_unique<VirtualAddressDeviceBuffer>(size, type, std::move(vmAllocator));
|
||||
}
|
||||
if (static_cast<bool>(mPool))
|
||||
{
|
||||
return std::make_unique<DeviceBuffer>(size, type, CudaAllocatorAsync{mStream, mPool});
|
||||
@ -49,6 +53,10 @@ BufferManager::IBufferPtr BufferManager::gpu(std::size_t size, nvinfer1::DataTyp
|
||||
|
||||
BufferManager::ITensorPtr BufferManager::gpu(nvinfer1::Dims dims, nvinfer1::DataType type) const
|
||||
{
|
||||
if (auto vmAllocator = getVirtualMemoryAllocator())
|
||||
{
|
||||
return std::make_unique<VirtualAddressDeviceTensor>(dims, type, std::move(vmAllocator));
|
||||
}
|
||||
if (static_cast<bool>(mPool))
|
||||
{
|
||||
return std::make_unique<DeviceTensor>(dims, type, CudaAllocatorAsync{mStream, mPool});
|
||||
@ -59,11 +67,19 @@ BufferManager::ITensorPtr BufferManager::gpu(nvinfer1::Dims dims, nvinfer1::Data
|
||||
|
||||
BufferManager::IBufferPtr BufferManager::gpuSync(std::size_t size, nvinfer1::DataType type)
|
||||
{
|
||||
if (auto vmAllocator = getVirtualMemoryAllocator())
|
||||
{
|
||||
return std::make_unique<VirtualAddressDeviceBuffer>(size, type, std::move(vmAllocator));
|
||||
}
|
||||
return std::make_unique<StaticDeviceBuffer>(size, type, CudaAllocator{});
|
||||
}
|
||||
|
||||
BufferManager::ITensorPtr BufferManager::gpuSync(nvinfer1::Dims dims, nvinfer1::DataType type)
|
||||
{
|
||||
if (auto vmAllocator = getVirtualMemoryAllocator())
|
||||
{
|
||||
return std::make_unique<VirtualAddressDeviceTensor>(dims, type, std::move(vmAllocator));
|
||||
}
|
||||
return std::make_unique<StaticDeviceTensor>(dims, type, CudaAllocator{});
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
|
||||
#include "tensorrt_llm/runtime/memoryCounters.h"
|
||||
#include "tensorrt_llm/runtime/virtualMemory.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
@ -500,6 +501,36 @@ protected:
|
||||
|
||||
using PinnedPoolAllocator = PoolAllocator<PinnedAllocator>;
|
||||
|
||||
class CudaVirtualMemoryAllocatorAdaptor
|
||||
: public BaseAllocator<CudaVirtualMemoryAllocatorAdaptor, MemoryType::kGPU, /* count */ false>,
|
||||
CudaVirtualMemoryAllocator
|
||||
{
|
||||
// Updates to MemoryCounters are done in the Creator to reflect memory usage more precisely.
|
||||
using Base = BaseAllocator<CudaVirtualMemoryAllocatorAdaptor, MemoryType::kGPU, false>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
// Not explicit, to allow implicit conversion from CudaVirtualMemoryAllocator
|
||||
CudaVirtualMemoryAllocatorAdaptor(CudaVirtualMemoryAllocator const& allocator)
|
||||
: CudaVirtualMemoryAllocator(allocator)
|
||||
{
|
||||
}
|
||||
|
||||
using Base::allocate;
|
||||
using Base::deallocate;
|
||||
|
||||
protected:
|
||||
void allocateImpl(PointerType* ptr, std::size_t n) const
|
||||
{
|
||||
this->CudaVirtualMemoryAllocator::allocate(ptr, n, tensorrt_llm::common::getDevice());
|
||||
}
|
||||
|
||||
void deallocateImpl(PointerType ptr, std::size_t n) const
|
||||
{
|
||||
this->CudaVirtualMemoryAllocator::deallocate(ptr, n);
|
||||
}
|
||||
};
|
||||
|
||||
// Adopted from https://github.com/NVIDIA/TensorRT/blob/release/8.6/samples/common/buffers.h
|
||||
|
||||
//!
|
||||
@ -508,17 +539,10 @@ using PinnedPoolAllocator = PoolAllocator<PinnedAllocator>;
|
||||
//! \details This templated RAII (Resource Acquisition Is Initialization) class handles the allocation,
|
||||
//! deallocation, querying of buffers on both the device and the host.
|
||||
//! It can handle data of arbitrary types because it stores byte buffers.
|
||||
//! The template parameters AllocFunc and FreeFunc are used for the
|
||||
//! allocation and deallocation of the buffer.
|
||||
//! AllocFunc must be a functor that takes in (void** ptr, size_t size)
|
||||
//! and returns bool. ptr is a pointer to where the allocated buffer address should be stored.
|
||||
//! size is the amount of memory in bytes to allocate.
|
||||
//! The boolean indicates whether or not the memory allocation was successful.
|
||||
//! FreeFunc must be a functor that takes in (void* ptr) and returns void.
|
||||
//! ptr is the allocated buffer address. It must work with nullptr input.
|
||||
//! The template parameter TAllocator must inherit from BaseAllocator.
|
||||
//!
|
||||
template <typename TAllocator>
|
||||
class GenericBuffer : virtual public IBuffer
|
||||
class GenericBuffer : virtual public IBuffer, TAllocator // Inherit from TAllocator for EBO
|
||||
{
|
||||
public:
|
||||
using AllocatorType = TAllocator;
|
||||
@ -527,20 +551,27 @@ public:
|
||||
//! \brief Construct an empty buffer.
|
||||
//!
|
||||
explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {}) // NOLINT(*-pro-type-member-init)
|
||||
: GenericBuffer{0, type, std::move(allocator)} {};
|
||||
: GenericBuffer{0, type, std::move(allocator)}
|
||||
{
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Construct a buffer with the specified allocation size in number of elements.
|
||||
//!
|
||||
explicit GenericBuffer( // NOLINT(*-pro-type-member-init)
|
||||
std::size_t size, nvinfer1::DataType type, TAllocator allocator = {})
|
||||
: GenericBuffer{size, size, type, std::move(allocator)} {};
|
||||
: GenericBuffer{size, size, type, std::move(allocator)}
|
||||
{
|
||||
}
|
||||
|
||||
GenericBuffer(GenericBuffer const& other) = delete;
|
||||
GenericBuffer& operator=(GenericBuffer const& buf) = delete;
|
||||
|
||||
GenericBuffer(GenericBuffer&& buf) noexcept
|
||||
: mSize{buf.mSize}
|
||||
: TAllocator(static_cast<TAllocator&&>(buf))
|
||||
, mSize{buf.mSize}
|
||||
, mCapacity{buf.mCapacity}
|
||||
, mType{buf.mType}
|
||||
, mAllocator{std::move(buf.mAllocator)}
|
||||
, mBuffer{buf.mBuffer}
|
||||
{
|
||||
buf.mSize = 0;
|
||||
@ -552,11 +583,11 @@ public:
|
||||
{
|
||||
if (this != &buf)
|
||||
{
|
||||
mAllocator.deallocate(mBuffer, toBytes(mCapacity));
|
||||
this->TAllocator::deallocate(mBuffer, toBytes(mCapacity));
|
||||
mSize = buf.mSize;
|
||||
mCapacity = buf.mCapacity;
|
||||
mType = buf.mType;
|
||||
mAllocator = std::move(buf.mAllocator);
|
||||
*static_cast<TAllocator*>(this) = static_cast<TAllocator&&>(buf);
|
||||
mBuffer = buf.mBuffer;
|
||||
// Reset buf.
|
||||
buf.mSize = 0;
|
||||
@ -615,7 +646,7 @@ public:
|
||||
//!
|
||||
[[nodiscard]] MemoryType getMemoryType() const override
|
||||
{
|
||||
return mAllocator.getMemoryType();
|
||||
return this->TAllocator::getMemoryType();
|
||||
}
|
||||
|
||||
//!
|
||||
@ -625,8 +656,8 @@ public:
|
||||
{
|
||||
if (mCapacity < newSize)
|
||||
{
|
||||
mAllocator.deallocate(mBuffer, toBytes(mCapacity));
|
||||
mBuffer = mAllocator.allocate(toBytes(newSize));
|
||||
this->TAllocator::deallocate(mBuffer, toBytes(mCapacity));
|
||||
mBuffer = this->TAllocator::allocate(toBytes(newSize));
|
||||
mCapacity = newSize;
|
||||
}
|
||||
mSize = newSize;
|
||||
@ -637,7 +668,7 @@ public:
|
||||
//!
|
||||
void release() override
|
||||
{
|
||||
mAllocator.deallocate(mBuffer, toBytes(mCapacity));
|
||||
this->TAllocator::deallocate(mBuffer, toBytes(mCapacity));
|
||||
mSize = 0;
|
||||
mCapacity = 0;
|
||||
mBuffer = nullptr;
|
||||
@ -647,7 +678,7 @@ public:
|
||||
{
|
||||
try
|
||||
{
|
||||
mAllocator.deallocate(mBuffer, toBytes(mCapacity));
|
||||
this->TAllocator::deallocate(mBuffer, toBytes(mCapacity));
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
@ -657,11 +688,11 @@ public:
|
||||
|
||||
protected:
|
||||
explicit GenericBuffer(std::size_t size, std::size_t capacity, nvinfer1::DataType type, TAllocator allocator = {})
|
||||
: mSize{size}
|
||||
: TAllocator{std::move(allocator)}
|
||||
, mSize{size}
|
||||
, mCapacity{capacity}
|
||||
, mType{type}
|
||||
, mAllocator{std::move(allocator)}
|
||||
, mBuffer{capacity > 0 ? mAllocator.allocate(toBytes(capacity)) : nullptr}
|
||||
, mBuffer{capacity > 0 ? this->TAllocator::allocate(toBytes(capacity)) : nullptr}
|
||||
{
|
||||
TLLM_CHECK(size <= capacity);
|
||||
TLLM_CHECK(capacity == 0 || size > 0);
|
||||
@ -670,7 +701,6 @@ protected:
|
||||
private:
|
||||
std::size_t mSize{0}, mCapacity{0};
|
||||
nvinfer1::DataType mType;
|
||||
TAllocator mAllocator;
|
||||
void* mBuffer;
|
||||
};
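// Illustrative sketch (not part of this change): the reason GenericBuffer now inherits from TAllocator
// instead of storing an mAllocator member. With a stateless allocator the empty base optimization
// removes its storage entirely:
//
//   struct EmptyAllocator {};
//   struct AsMember { void* buffer; EmptyAllocator allocator; }; // allocator still occupies padded space
//   struct AsBase : EmptyAllocator { void* buffer; };            // EBO: typically just sizeof(void*)
//   static_assert(sizeof(AsBase) <= sizeof(AsMember), "EBO never makes the object larger");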
|
||||
|
||||
@ -834,6 +864,7 @@ using HostBuffer = GenericBuffer<HostAllocator>;
|
||||
using PinnedBuffer = GenericBuffer<PinnedAllocator>;
|
||||
using PinnedPoolBuffer = GenericBuffer<PinnedPoolAllocator>;
|
||||
using UVMBuffer = GenericBuffer<UVMAllocator>;
|
||||
using VirtualAddressDeviceBuffer = GenericBuffer<CudaVirtualMemoryAllocatorAdaptor>;
|
||||
|
||||
template <typename T>
|
||||
std::make_unsigned_t<T> nonNegative(T value)
|
||||
@ -1069,5 +1100,6 @@ using HostTensor = GenericTensor<HostAllocator>;
|
||||
using PinnedTensor = GenericTensor<PinnedAllocator>;
|
||||
using PinnedPoolTensor = GenericTensor<PinnedPoolAllocator>;
|
||||
using UVMTensor = GenericTensor<UVMAllocator>;
|
||||
using VirtualAddressDeviceTensor = GenericTensor<CudaVirtualMemoryAllocatorAdaptor>;
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
433
cpp/tensorrt_llm/runtime/virtualMemory.cpp
Normal file
@ -0,0 +1,433 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/virtualMemory.h"
|
||||
#include "bufferManager.h"
|
||||
|
||||
#include <forward_list>
|
||||
#include <shared_mutex>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct ScopeGuard
|
||||
{
|
||||
bool const& ok;
|
||||
T t;
|
||||
|
||||
~ScopeGuard() noexcept(noexcept(t()))
|
||||
{
|
||||
if (!ok)
|
||||
{
|
||||
t();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
ScopeGuard(bool const&, T) -> ScopeGuard<T>;
|
||||
|
||||
} // namespace
|
||||
|
||||
void CUDAVirtualMemoryChunk::materialize()
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(status() == RELEASED, "virtual memory not in RELEASED status, is: %d", status());
|
||||
mHandle = mCreator->create();
|
||||
|
||||
// Track the number of configurators that ran, so release can correctly tear them down.
|
||||
for (auto const& conf : mConfigurators)
|
||||
{
|
||||
conf->setup(mHandle); // May throw
|
||||
++mState;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Callable, typename... Args>
|
||||
static bool safe_invoke_helper(std::exception_ptr& ep, char const* msg, Callable&& f, Args&&... args) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
std::invoke(std::forward<Callable>(f), std::forward<Args>(args)...);
|
||||
return true;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (ep)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::rethrow_exception(ep);
|
||||
}
|
||||
catch (std::exception& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(msg, e.what());
|
||||
}
|
||||
}
|
||||
ep = std::current_exception();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void CUDAVirtualMemoryChunk::_release(bool destructing)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(status() == MATERIALIZED || (status() == ERRORED && mState != INVALID_STATE),
|
||||
"virtual memory is in status %d which cannot be released", status());
|
||||
size_t const count = mConfigurators.size();
|
||||
size_t const start = count - mState;
|
||||
|
||||
// Revert materialize(). Only configurators that ran setup() successfully
|
||||
// will have their teardown() called.
|
||||
// Never returns early on exceptions. The last exception will be rethrown, and
|
||||
// previous ones will be logged.
|
||||
std::exception_ptr ePtr{};
|
||||
auto const* msg = "Multiple exceptions thrown during release. The previous exception is: %s";
|
||||
for (size_t i = start; i < count; ++i)
|
||||
{
|
||||
safe_invoke_helper(
|
||||
ePtr, msg, &Configurator::teardown, mConfigurators[count - i - 1].get(), mHandle, destructing);
|
||||
}
|
||||
safe_invoke_helper(ePtr, msg, &Creator::release, mCreator.get(), mHandle, destructing);
|
||||
mHandle = {};
|
||||
mState = 0;
|
||||
|
||||
if (ePtr != nullptr)
|
||||
{
|
||||
mState = INVALID_STATE;
|
||||
std::rethrow_exception(ePtr);
|
||||
}
|
||||
}
|
||||
|
||||
void OffloadConfigurator::setup(CUmemGenericAllocationHandle)
|
||||
{
|
||||
if (mBackedStorage != nullptr)
|
||||
{
|
||||
if (mOndemand)
|
||||
{
|
||||
TLLM_CU_CHECK(cuMemcpyHtoD_v2(mAddress, mBackedStorage->data(), mSize));
|
||||
mBackedStorage.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CU_CHECK(cuMemcpyHtoDAsync_v2(mAddress, mBackedStorage->data(), mSize, mStream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OffloadConfigurator::teardown(CUmemGenericAllocationHandle, bool destructing)
|
||||
{
|
||||
if (destructing)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (mBackedStorage == nullptr)
|
||||
{
|
||||
switch (mBackType)
|
||||
{
|
||||
case MemoryType::kCPU: mBackedStorage = BufferManager::cpu(mSize, nvinfer1::DataType::kINT8); break;
|
||||
case MemoryType::kPINNED: mBackedStorage = BufferManager::pinned(mSize, nvinfer1::DataType::kINT8); break;
|
||||
default: TLLM_THROW("Unknown memory type: %d", static_cast<int32_t>(mBackType));
|
||||
}
|
||||
}
|
||||
// We have to synchronize here, or the memory may be unmapped before the copy operation.
|
||||
TLLM_CU_CHECK_FREE_RESOURCE(cuMemcpyDtoH_v2(mBackedStorage->data(), mAddress, mSize));
|
||||
}
|
||||
|
||||
void CudaVirtualMemoryManager::add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk&& memory)
|
||||
{
|
||||
bool success = false;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
memory.status() == CUDAVirtualMemoryChunk::RELEASED || memory.status() == CUDAVirtualMemoryChunk::MATERIALIZED,
|
||||
"CudaVirtualMemoryManager: bad virtual memory status");
|
||||
|
||||
std::unique_lock lock(mMutex);
|
||||
auto [memIt, created] = mMemories.try_emplace(handle, Entry{});
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
created, "CudaVirtualMemoryManager: handle 0x%016zx already being used by another memory", handle);
|
||||
ScopeGuard eraseMemIt{success, [&, memIt_ = memIt] { mMemories.erase(memIt_); }};
|
||||
|
||||
auto const entryIt = mEntries.emplace(std::move(tag), memIt);
|
||||
entryIt->second->second.mEntryIt = entryIt;
|
||||
|
||||
memIt->second.mMemory = std::move(memory);
|
||||
success = true;
|
||||
}
|
||||
|
||||
void CudaVirtualMemoryManager::add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator,
|
||||
CUDAVirtualMemoryChunk::Configurators&& configurators)
|
||||
{
|
||||
std::unique_lock lock(mMutex);
|
||||
bool success = false;
|
||||
|
||||
auto [memIt, created] = mMemories.try_emplace(handle,
|
||||
Entry{
|
||||
{std::move(creator), std::move(configurators)},
|
||||
});
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
created, "CudaVirtualMemoryManager: handle 0x%016zx already being used by another memory", handle);
|
||||
ScopeGuard eraseMemIt{success, [&, memIt_ = memIt] { mMemories.erase(memIt_); }};
|
||||
|
||||
auto const entryIt = mEntries.emplace(std::move(tag), memIt);
|
||||
memIt->second.mEntryIt = entryIt;
|
||||
ScopeGuard eraseTagIt{success, [&] { mEntries.erase(entryIt); }};
|
||||
|
||||
try
|
||||
{
|
||||
// Hopefully we don't need to hold the mutex guarding mMemories and mEntries anymore.
|
||||
lock.unlock();
|
||||
memIt->second.mMemory.materialize();
|
||||
success = true;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
// ...unless materialize() throws and we need to rollback.
|
||||
lock.lock();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
CUDAVirtualMemoryChunk CudaVirtualMemoryManager::remove(uintptr_t handle) noexcept
|
||||
{
|
||||
std::unique_lock lock(mMutex);
|
||||
|
||||
return unsafeRemove(handle);
|
||||
}
|
||||
|
||||
CUDAVirtualMemoryChunk CudaVirtualMemoryManager::unsafeRemove(uintptr_t handle) noexcept
|
||||
{
|
||||
auto const nodeHandle = mMemories.extract(handle);
|
||||
if (!nodeHandle)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
mEntries.erase(nodeHandle.mapped().mEntryIt);
|
||||
|
||||
return std::move(nodeHandle.mapped().mMemory);
|
||||
}
|
||||
|
||||
void CudaVirtualMemoryManager::addBadHandle(uintptr_t handle) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
mBadHandles.push_back(handle);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uintptr_t> CudaVirtualMemoryManager::retrieveBadHandles() noexcept
|
||||
{
|
||||
return std::move(mBadHandles);
|
||||
}
|
||||
|
||||
size_t CudaVirtualMemoryManager::releaseWithTag(std::string const& tag)
|
||||
{
|
||||
std::unique_lock lock(mMutex);
|
||||
|
||||
std::exception_ptr ePtr{};
|
||||
auto [begin, end] = mEntries.equal_range(tag);
|
||||
size_t count = 0;
|
||||
for (auto it = begin; it != end;)
|
||||
{
|
||||
auto const handle = it->second->first;
|
||||
auto& memory = it->second->second.mMemory;
|
||||
++it; // element referenced by `it` will be invalidated by unsafeRemove(handle)
|
||||
if (memory.status() == CUDAVirtualMemoryChunk::MATERIALIZED)
|
||||
{
|
||||
if (!safe_invoke_helper(ePtr,
|
||||
"Multiple exceptions thrown during releaseWithTag. The previous exception is: %s",
|
||||
&CUDAVirtualMemoryChunk::release, &memory))
|
||||
{
|
||||
addBadHandle(handle);
|
||||
unsafeRemove(handle);
|
||||
}
|
||||
++count;
|
||||
}
|
||||
}
|
||||
|
||||
if (ePtr != nullptr)
|
||||
{
|
||||
std::rethrow_exception(ePtr);
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
size_t CudaVirtualMemoryManager::materializeWithTag(std::string const& tag)
|
||||
{
|
||||
std::unique_lock lock(mMutex);
|
||||
|
||||
auto [begin, end] = mEntries.equal_range(tag);
|
||||
size_t count = 0;
|
||||
|
||||
auto it = begin;
|
||||
|
||||
try
|
||||
{
|
||||
for (; it != end; ++it)
|
||||
{
|
||||
auto& memory = it->second->second.mMemory;
|
||||
if (memory.status() == CUDAVirtualMemoryChunk::RELEASED)
|
||||
{
|
||||
memory.materialize();
|
||||
++count;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
for (auto itRollback = begin; itRollback != it;)
|
||||
{
|
||||
auto const handle = itRollback->second->first;
|
||||
auto& memory = itRollback->second->second.mMemory;
|
||||
++itRollback;
|
||||
try
|
||||
{
|
||||
memory.release();
|
||||
}
|
||||
catch (std::exception& e)
|
||||
{
|
||||
addBadHandle(handle);
|
||||
unsafeRemove(handle);
|
||||
TLLM_LOG_ERROR("Additional exception thrown during rollback of materializeWithTag: %s", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
addBadHandle(it->second->first);
|
||||
unsafeRemove(it->second->first);
|
||||
|
||||
throw;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
static_assert(sizeof(void*) == sizeof(CUdeviceptr));
|
||||
|
||||
static CUdeviceptr deviceptr_cast(void* ptr)
|
||||
{
|
||||
CUdeviceptr ret{};
|
||||
std::memcpy(&ret, &ptr, sizeof(CUdeviceptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static void* deviceptr_cast(CUdeviceptr ptr)
|
||||
{
|
||||
void* ret{};
|
||||
std::memcpy(&ret, &ptr, sizeof(CUdeviceptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void CudaVirtualMemoryAllocator::allocate(Pointer* ptr, std::size_t n, int device) const
|
||||
{
|
||||
CUdeviceptr address{};
|
||||
std::size_t const pageAlignedSize = mConfig->pageAligned(n);
|
||||
TLLM_CU_CHECK(cuMemAddressReserve(&address, pageAlignedSize, 0, {}, 0));
|
||||
|
||||
CUDAVirtualMemoryChunk::Configurators configurators;
|
||||
configurators.push_back(std::make_unique<UnicastConfigurator>(address, n,
|
||||
CUmemAccessDesc{{
|
||||
CU_MEM_LOCATION_TYPE_DEVICE,
|
||||
device,
|
||||
},
|
||||
CU_MEM_ACCESS_FLAGS_PROT_READWRITE}));
|
||||
|
||||
switch (mConfig->mMode)
|
||||
{
|
||||
case NONE: break;
|
||||
case MEMSET:
|
||||
configurators.push_back(std::make_unique<MemsetConfigurator>(address, n, 0, mConfig->mBackStream->get()));
|
||||
break;
|
||||
case CPU:
|
||||
configurators.push_back(
|
||||
std::make_unique<OffloadConfigurator>(address, n, MemoryType::kCPU, mConfig->mBackStream->get()));
|
||||
break;
|
||||
case PINNED:
|
||||
configurators.push_back(
|
||||
std::make_unique<OffloadConfigurator>(address, n, MemoryType::kPINNED, mConfig->mBackStream->get()));
|
||||
break;
|
||||
}
|
||||
|
||||
mConfig->mManager.add(address, mConfig->mTag,
|
||||
std::make_unique<LocalCreator<>>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE,
|
||||
{
|
||||
CU_MEM_LOCATION_TYPE_DEVICE,
|
||||
device,
|
||||
}},
|
||||
n),
|
||||
std::move(configurators));
|
||||
|
||||
*ptr = deviceptr_cast(address);
|
||||
}
|
||||
|
||||
void CudaVirtualMemoryAllocator::deallocate(Pointer ptr, std::size_t n) const
|
||||
{
|
||||
auto const address = deviceptr_cast(ptr);
|
||||
mConfig->mManager.remove(address);
|
||||
|
||||
std::size_t const pageAlignedSize = mConfig->pageAligned(n);
|
||||
TLLM_CU_CHECK_FREE_RESOURCE(cuMemAddressFree(address, pageAlignedSize));
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
CudaVirtualMemoryManager& getVirtualMemoryManager()
|
||||
{
|
||||
static CudaVirtualMemoryManager manager;
|
||||
return manager;
|
||||
}
|
||||
|
||||
using AllocConf = CudaVirtualMemoryAllocator::Configuration;
|
||||
|
||||
AllocConf AllocConf::backgroundConfiguration{getVirtualMemoryManager(), "", NONE, nullptr, true};
|
||||
|
||||
static const std::shared_ptr<AllocConf> bgConf{std::shared_ptr<AllocConf>{}, &AllocConf::backgroundConfiguration};
|
||||
|
||||
static std::shared_mutex currentConfMutex;
|
||||
static std::shared_ptr<AllocConf> currentConf = bgConf;
|
||||
|
||||
CudaVirtualMemoryAllocator getVirtualMemoryAllocator()
|
||||
{
|
||||
std::shared_lock lock(currentConfMutex);
|
||||
return CudaVirtualMemoryAllocator{currentConf};
|
||||
}
|
||||
|
||||
void setVirtualMemoryAllocator(
|
||||
std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr<CudaStream> backStream)
|
||||
{
|
||||
std::unique_lock lock(currentConfMutex);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(currentConf == bgConf,
|
||||
"An active virtual memory allocator (tag: %s, mode: %d, stream: %p) is already present",
|
||||
currentConf->mTag.c_str(), currentConf->mMode, currentConf->mBackStream.get());
|
||||
currentConf = std::make_shared<AllocConf>(getVirtualMemoryManager(), tag, mode, backStream);
|
||||
}
|
||||
|
||||
void clearVirtualMemoryAllocator()
|
||||
{
|
||||
std::unique_lock lock(currentConfMutex);
|
||||
currentConf = bgConf;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
@ -85,6 +85,7 @@ add_library(
|
||||
selectiveScanOp.cpp
|
||||
userbuffersFinalizeOp.cpp
|
||||
userbuffersTensor.cpp
|
||||
virtualMemoryAllocator.cpp
|
||||
weightOnlyQuantGemm.cpp
|
||||
weightOnlyQuantOp.cpp
|
||||
mtpOp.cpp
|
||||
|
||||
60
cpp/tensorrt_llm/thop/virtualMemoryAllocator.cpp
Normal file
@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/virtualMemory.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
extern "C"
|
||||
{
|
||||
|
||||
void* tensorrt_llm_virtual_memory_alloc(ssize_t size, int device, cudaStream_t) noexcept
|
||||
{
|
||||
void* ptr{};
|
||||
try
|
||||
{
|
||||
tensorrt_llm::runtime::getVirtualMemoryAllocator().allocate(&ptr, size, device);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_EXCEPTION(e);
|
||||
ptr = {};
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
TLLM_LOG_ERROR("Unknown exception thrown allocating virtual memory");
|
||||
ptr = {};
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void tensorrt_llm_virtual_memory_free(void* ptr, ssize_t size, cudaStream_t) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
tensorrt_llm::runtime::getVirtualMemoryAllocator().deallocate(ptr, size);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_EXCEPTION(e);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
TLLM_LOG_ERROR("Unknown exception thrown deallocating virtual memory");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -28,6 +28,7 @@ add_gtest(tllmRuntimeTest tllmRuntimeTest.cpp)
|
||||
add_gtest(transposeKVKernelTest transposeKVKernelTest.cpp)
|
||||
add_gtest(userBufferTest userBufferTest.cpp)
|
||||
add_gtest(utilsTest utilsTest.cpp)
|
||||
add_gtest(virtualMemoryTest virtualMemoryTest.cpp)
|
||||
add_gtest(workerPoolTest workerPoolTest.cpp)
|
||||
add_gtest(worldConfigTest worldConfigTest.cpp)
|
||||
|
||||
|
||||
1572
cpp/tests/unit_tests/runtime/virtualMemoryTest.cpp
Normal file
File diff suppressed because it is too large
88
tensorrt_llm/_torch/virtual_memory.py
Normal file
@ -0,0 +1,88 @@
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm.bindings.internal.runtime import \
|
||||
CudaVirtualMemoryAllocatorRestoreMode as RestoreMode
|
||||
from tensorrt_llm.bindings.internal.runtime import (
|
||||
clear_virtual_memory_allocator, get_virtual_memory_manager,
|
||||
set_virtual_memory_allocator)
|
||||
|
||||
__all__ = [
|
||||
"RestoreMode", "maybe_scope", "scope", "release_with_tag",
|
||||
"materialize_with_tag"
|
||||
]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_torch_pluggable_virtual_memory_allocator():
|
||||
th_common = next(path for path in torch.classes.loaded_libraries
|
||||
if 'th_common' in path)
|
||||
virtual_memory_allocator = torch.cuda.CUDAPluggableAllocator(
|
||||
th_common, 'tensorrt_llm_virtual_memory_alloc',
|
||||
'tensorrt_llm_virtual_memory_free')
|
||||
return virtual_memory_allocator.allocator()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _virtual_memory_helper(tag: str, mode: RestoreMode):
|
||||
stream = torch.cuda.current_stream()
|
||||
set_virtual_memory_allocator(tag, mode, stream.cuda_stream)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
clear_virtual_memory_allocator()
|
||||
|
||||
|
||||
def _scope(
|
||||
tag: str,
|
||||
mode: RestoreMode = RestoreMode.NONE
|
||||
) -> Generator[torch.cuda.MemPool, None, None]:
|
||||
"""A context manager that routes allocations to virtual memory allocator
|
||||
using the given tag and backing mode.
|
||||
|
||||
:param tag: The tag to reference the memory for release and materialize
|
||||
:param mode: The backing mode that determines how the memory content is backed up
|
||||
"""
|
||||
pool = torch.cuda.MemPool(_get_torch_pluggable_virtual_memory_allocator())
|
||||
with _virtual_memory_helper(tag, mode), torch.cuda.use_mem_pool(pool):
|
||||
yield pool
|
||||
|
||||
|
||||
scope = contextmanager(_scope)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def maybe_scope(
|
||||
enable: bool,
|
||||
tag: str,
|
||||
mode: RestoreMode = RestoreMode.NONE
|
||||
) -> Generator[torch.cuda.MemPool | None, None, None]:
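"""Like ``scope``, but only routes allocations to virtual memory when ``enable`` is True."""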
|
||||
if enable:
|
||||
yield from _scope(tag, mode)
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def release_with_tag(*tags: str) -> int:
|
||||
"""Release virtual memory allocated with given tags
|
||||
|
||||
:param tags: The tags of the scopes in which the virtual memory was allocated
|
||||
:return: Number of memory blobs released
|
||||
"""
|
||||
manager = get_virtual_memory_manager()
|
||||
released_blobs = sum(manager.release_with_tag(tag) for tag in tags)
|
||||
return released_blobs
|
||||
|
||||
|
||||
def materialize_with_tag(*tags: str) -> int:
|
||||
"""Materialize virtual memory allocated with given tags
|
||||
|
||||
:param tags: The tag of the scope when the virtual memory is allocated
|
||||
:return: Number of memory blobs materialized
|
||||
"""
|
||||
manager = get_virtual_memory_manager()
|
||||
materialized_blobs = sum(manager.materialize_with_tag(tag) for tag in tags)
|
||||
return materialized_blobs
|
||||
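A minimal usage sketch of the module above; it is not part of the diff, and the tag name and buffer size are illustrative. With the default RestoreMode.NONE the contents are not preserved across a release/materialize cycle, so the buffer has to be re-initialized afterwards, exactly as the tests below do:

import torch

from tensorrt_llm._torch import virtual_memory

# Allocations made inside the scope are served by the virtual memory
# allocator and tagged "scratch"; the tests below hold on to the returned
# pool until the allocations are gone, so this sketch does the same.
with virtual_memory.scope("scratch") as pool:
    scratch = torch.zeros(1 << 20, dtype=torch.int8, device='cuda')

torch.cuda.synchronize()
virtual_memory.release_with_tag("scratch")      # unmap the physical memory

virtual_memory.materialize_with_tag("scratch")  # map fresh physical memory back
scratch.zero_()  # RestoreMode.NONE does not back contents up, so re-initialize

del scratch, pool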
254
tests/unittest/_torch/test_virtual_memory.py
Normal file
@ -0,0 +1,254 @@
import gc
import os
import warnings

import pynvml
import pytest
import torch

import tensorrt_llm
from tensorrt_llm._torch import virtual_memory
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.bindings.internal.batch_manager import CacheType
from tensorrt_llm.mapping import Mapping


@pytest.fixture(scope="function", autouse=True)
def cuda_sync_fixture():
    """
    Synchronizes CUDA to catch device errors.
    """

    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()


@pytest.fixture(scope="module")
def memory_info_available():
    """
    Checks if NVML can get per-process memory information.
    """

    # Allocate a small tensor to test memory tracking
    tensor = torch.zeros(4096, dtype=torch.int32, device='cuda')
    torch.cuda.synchronize()

    # Try to get memory usage
    usage = get_current_process_memory_info()

    # Clean up
    del tensor
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    if usage == 0:
        warnings.warn("Per process memory information unavailable.")
        return False

    return True


@pytest.fixture(scope="module", autouse=True)
def nvml_init():
    pynvml.nvmlInit()


def get_current_process_memory_info() -> int:
    """
    Returns GPU memory usage for the current process in bytes.
    """
    # Get current process ID
    current_pid = os.getpid()
    # Get device handle for GPU 0
    device_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

    # Get running processes
    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle)

    # Find current process
    for process in processes:
        if process.pid == current_pid:
            return process.usedGpuMemory

    return 0


@pytest.fixture(scope="function", autouse=True)
def clean_cache():
    gc.collect()
    torch.cuda.empty_cache()
    yield
    gc.collect()
    torch.cuda.empty_cache()


def test_basic(memory_info_available):
    memory_usage_begin = get_current_process_memory_info()

    alloc_size = 256 * 1024 * 1024
    tag = "test_tag"

    with virtual_memory.scope(tag) as pool:
        tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda')
    memory_usage_materialized = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin + alloc_size == memory_usage_materialized

    assert tensor[0].item() == 42

    torch.cuda.synchronize()
    virtual_memory.release_with_tag(tag)

    memory_usage_released = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin == memory_usage_released

    torch.cuda.synchronize()
    virtual_memory.materialize_with_tag(tag)

    memory_usage_rematerialized = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin + alloc_size == memory_usage_rematerialized

    torch.fill_(tensor, 24)
    assert tensor[0].item() == 24

    del tensor
    del pool

    memory_usage_end = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin == memory_usage_end


def test_restore():
    alloc_size = 1024 * 1024
    tag = "test_tag"

    with virtual_memory.scope(tag, virtual_memory.RestoreMode.PINNED) as pool:
        tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda')

    assert tensor[0].item() == 42

    torch.cuda.synchronize()
    virtual_memory.release_with_tag(tag)

    torch.cuda.synchronize()

    virtual_memory.materialize_with_tag(tag)
    torch.cuda.synchronize()

    assert tensor[0].item() == 42

    del tensor
    del pool


def test_kv_cache_manager(memory_info_available):
    kv_cache_params = {
        "kv_cache_config": KvCacheConfig(max_tokens=1024),
        "kv_cache_type": CacheType.SELF,
        "num_layers": 8,
        "num_kv_heads": 256,
        "head_dim": 64,
        "tokens_per_block": 64,
        "max_seq_len": 1024,
        "max_batch_size": 1,
        "mapping": Mapping(world_size=1, tp_size=1, rank=0),
        "dtype": tensorrt_llm.bindings.DataType.FP8,
    }

    mgr = KVCacheManager(**kv_cache_params)
    mgr.shutdown()
    del mgr

    memory_usage_begin = get_current_process_memory_info()

    tag = "test_tag"
    cache_size = torch.empty(
        [
            2,  # KV
            8,  # Layers
            256,  # Heads
            1024,  # Tokens
            64,  # Head dim
        ],
        dtype=torch.float8_e4m3fn,
        device='meta')

    alloc_size = cache_size.nelement()

    with virtual_memory.scope(tag) as pool:
        mgr = KVCacheManager(**kv_cache_params)
    memory_usage_materialized = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin + alloc_size == memory_usage_materialized

    torch.cuda.synchronize()
    virtual_memory.release_with_tag(tag)

    memory_usage_released = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin == memory_usage_released

    torch.cuda.synchronize()
    virtual_memory.materialize_with_tag(tag)

    memory_usage_rematerialized = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin + alloc_size == memory_usage_rematerialized

    mgr.shutdown()
    del mgr
    del pool

    memory_usage_end = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_begin == memory_usage_end


def test_cuda_graph(memory_info_available):

    def work(input: torch.Tensor) -> torch.Tensor:
        intermediate = input + input
        output = input + intermediate
        return output

    g = torch.cuda.CUDAGraph()
    tag = "cuda_graph"

    with virtual_memory.scope(tag) as pool:
        static_input = torch.ones(1024, dtype=torch.float32, device='cuda')
        static_output = torch.zeros(1024, dtype=torch.float32, device='cuda')

        with torch.cuda.graph(g):
            static_output.copy_(work(static_input))

    torch.fill_(static_input, 1.0)
    g.replay()

    torch.cuda.synchronize()
    assert static_output[0].item() == 3.0

    memory_usage_before = get_current_process_memory_info()

    torch.cuda.synchronize()
    virtual_memory.release_with_tag(tag)

    memory_usage_released = get_current_process_memory_info()
    if memory_info_available:
        assert memory_usage_released < memory_usage_before

    torch.cuda.synchronize()
    virtual_memory.materialize_with_tag(tag)

    torch.fill_(static_input, 1.0)
    torch.fill_(static_output, 0.0)
    g.replay()

    torch.cuda.synchronize()
    assert static_output[0].item() == 3.0

    del static_input, static_output, g, pool
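Taken together, the tests above sketch the sleep/wakeup flow this change targets. The snippet below is illustrative only: enable_sleep is a hypothetical flag and the buffer stands in for whatever an engine would allocate; only maybe_scope, release_with_tag, materialize_with_tag and RestoreMode come from the new module.

import torch

from tensorrt_llm._torch import virtual_memory

ENGINE_TAG = "engine_buffers"  # illustrative tag
enable_sleep = True            # hypothetical flag controlling sleep support

# Route allocations through the virtual memory allocator only when sleep
# support is requested; with enable_sleep=False, maybe_scope is a no-op.
with virtual_memory.maybe_scope(enable_sleep, ENGINE_TAG,
                                virtual_memory.RestoreMode.PINNED) as pool:
    buffers = torch.empty(64 * 1024 * 1024, dtype=torch.uint8, device='cuda')

# "Sleep": unmap the physical memory behind every allocation carrying the tag,
# after backing the contents up to pinned host memory (RestoreMode.PINNED).
torch.cuda.synchronize()
virtual_memory.release_with_tag(ENGINE_TAG)

# "Wakeup": map physical memory back and restore the saved contents.
virtual_memory.materialize_with_tag(ENGINE_TAG)
torch.cuda.synchronize()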