/* * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/common/config.h" #include #include #include TRTLLM_NAMESPACE_BEGIN namespace common { /** * @brief Wrapper that holds an optional reference and integrates with std containers. * @details The wrapper uses a std::optional> at its core. When constructed with a unique or shared ptr with a nullptr value, it is interpreted as not holding a value, meaning the std::optional of the wrapper object will be false. * * @tparam T */ template class OptionalRef { private: std::optional> opt; public: OptionalRef() = default; OptionalRef(T& ref) : opt(std::ref(ref)) { } OptionalRef(std::nullopt_t) : opt(std::nullopt) { } OptionalRef(std::shared_ptr const& ptr) : opt(ptr ? std::optional>(std::ref(*ptr)) : std::nullopt) { } // Constructor for std::shared_ptr> when T is const-qualified template >> OptionalRef(std::shared_ptr> const& ptr) : opt(ptr ? std::optional>(std::ref(*ptr)) : std::nullopt) { } OptionalRef(std::unique_ptr const& ptr) : opt(ptr ? std::optional>(std::ref(*ptr)) : std::nullopt) { } // Constructor for std::unique_ptr> when T is const-qualified template >> OptionalRef(std::unique_ptr> const& ptr) : opt(ptr ? std::optional>(std::ref(*ptr)) : std::nullopt) { } T* operator->() const { return opt ? &(opt->get()) : nullptr; } T& operator*() const { return opt->get(); } explicit operator bool() const { return opt.has_value(); } bool has_value() const { return opt.has_value(); } T& value() const { return opt->get(); } }; } // namespace common TRTLLM_NAMESPACE_END