TensorRT-LLMs/cpp/tensorrt_llm/common/envUtils.h
Chuang Zhu 44cfd757b2
Agent interface impl for NIXL (#4125)
* agentConnection
recv
agentState
NIXL interfaces
update cmakelists
nixl improve
remove cppzmq
fix
transferAgent remove register
work for cache Test
reduce sleep time
fix test
intergarte
nixl env
fix rebase error
cpp test
stash for send metaData
loadRemoteMD after fetchRemoteMD
workaround for mixed gen and context
test_env
avoid port conflict in test
* format
* use std::string
* typo
* fix transferAgentTest

---------

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2025-05-22 09:09:41 +08:00


/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm::common
{
// Useful for injecting debug code that is controlled by an environment variable.
std::optional<int32_t> getIntEnv(char const* name);
std::optional<size_t> getUInt64Env(char const* name);
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Whether XQA JIT is enabled.
//
// Returns the value of the TRTLLM_ENABLE_XQA_JIT env var. If the env var is not set, std::nullopt is returned.
std::optional<bool> getEnvEnableXQAJIT();
// A value of 0 means heuristics are used.
std::optional<int32_t> getEnvXqaBlocksPerSequence();
// Whether to use tileSizeKv64 for multiCtasKvMode of trtllm-gen kernels.
bool getEnvUseTileSizeKv64ForTrtllmGen();
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
int getEnvMmhaKernelBlockSize();
// Whether PDL is enabled.
bool getEnvEnablePDL();
bool getEnvUseUCXKvCache();
bool getEnvUseMPIKvCache();
bool getEnvUseNixlKvCache();
std::string getEnvUCXInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
bool getEnvRequestKVCacheConcurrent();
bool getEnvDisableKVCacheTransferOverlap();
bool getEnvEnableReceiveKVCacheParallel();
std::string getEnvKVCacheTransferOutputPath();
bool getEnvTryZCopyForKVCacheTransfer();
// Force deterministic behavior for all kernels.
bool getEnvForceDeterministic();
// Force deterministic behavior for MoE plugin.
bool getEnvForceDeterministicMOE();
// Force deterministic behavior for attention plugin.
bool getEnvForceDeterministicAttention();
// Force deterministic behavior for the all-reduce plugin.
bool getEnvForceDeterministicAllReduce();
// Return the workspace size for custom all-reduce kernels.
// This only takes effect when force deterministic is enabled.
size_t getEnvAllReduceWorkspaceSize();
size_t getEnvKVCacheRecvBufferCount();
bool getEnvKVCacheTransferUseAsyncBuffer();
size_t getEnvKVCacheSendMaxConcurrenceNum();
size_t getEnvMemSizeForKVCacheTransferBuffer();
uint16_t getEnvNixlPort();
} // namespace tensorrt_llm::common
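
// Illustration only, not part of the original header. The declarations above
// expose environment-variable lookups such as getIntEnv, whose definitions are
// not shown here. The following is a minimal sketch of how a helper of that
// shape could be written, assuming std::getenv and std::from_chars are used.
// The namespace `sketch` and the names parseIntEnv and flagEnabled are
// hypothetical, and the real TensorRT-LLM definitions may parse values
// differently.

```cpp
#include <charconv>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <optional>
#include <system_error>

namespace sketch
{
// Hypothetical helper: returns the variable's value as int32_t, or
// std::nullopt if it is unset or cannot be parsed as a whole integer.
inline std::optional<int32_t> parseIntEnv(char const* name)
{
    char const* value = std::getenv(name);
    if (value == nullptr)
    {
        return std::nullopt; // variable is not set
    }
    int32_t parsed = 0;
    char const* end = value + std::strlen(value);
    auto const [ptr, ec] = std::from_chars(value, end, parsed);
    if (ec != std::errc{} || ptr != end)
    {
        return std::nullopt; // empty, non-numeric, trailing characters, or out of range
    }
    return parsed;
}

// Hypothetical boolean flag (an assumed convention): true only when the
// variable parses to the integer 1.
inline bool flagEnabled(char const* name)
{
    auto const value = parseIntEnv(name);
    return value.has_value() && *value == 1;
}
} // namespace sketch
```

// Returning std::optional lets a caller distinguish "not set" from an explicit
// value such as 0, which matches the std::optional return types declared above.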