chore: Mass integration of release/0.20 (#5082)

Signed-off-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Co-authored-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: Erin <14718778+hchings@users.noreply.github.com>
Co-authored-by: Frank <3429989+FrankD412@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
amirkl94 authored on 2025-06-17 14:32:02 +03:00; committed by GitHub
parent 13eef642e6
commit 8451a87742
11 changed files with 75 additions and 50 deletions


@@ -947,30 +947,32 @@ std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses]() -> bool { return !resp.empty() || mShutdown; };
-    auto storeResponses = [this, &resp = this->mResponses, &responses]()
+    auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
+    auto storeResponses = [this]()
     {
-        for (auto it = resp.cbegin(); it != resp.cend();)
+        std::vector<Response> responses;
+        for (auto it = mResponses.begin(); it != mResponses.end();)
         {
             responses.insert(responses.end(), it->second.begin(), it->second.end());
             addTerminatedReqId(it->second, it->first);
-            resp.erase(it++);
+            it = mResponses.erase(it);
         }
+        return responses;
     };

+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeResponses();
+            responses = storeResponses();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeResponses();
+        responses = storeResponses();
     }
     return responses;
 }
@@ -980,15 +982,16 @@ std::vector<Response> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses, reqId]() -> bool
-    { return (resp.find(reqId) != resp.end() && !resp.at(reqId).empty()) || mShutdown; };
-    auto storeIdResponse = [this, &resp = this->mResponses, &responses, reqId]()
+    auto pred = [this, reqId]() -> bool
+    { return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
+    auto storeIdResponse = [this, reqId]()
     {
-        responses.swap(resp.at(reqId));
-        resp.erase(reqId);
+        std::vector<Response> responses;
+        responses.swap(mResponses.at(reqId));
+        mResponses.erase(reqId);
         addTerminatedReqId(responses, reqId);
+        return responses;
     };

     // We don't process a terminated request again. Terminated request is defined as a response
@@ -1005,17 +1008,18 @@ std::vector<Response> Executor::Impl::awaitResponses(
         return {Response(reqId, err)};
     }

+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeIdResponse();
+            responses = storeIdResponse();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeIdResponse();
+        responses = storeIdResponse();
     }
     return responses;
 }
@@ -1025,26 +1029,27 @@ std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<std::vector<Response>> v(requestIds.size());
+    std::vector<std::vector<Response>> responses;
+    responses.reserve(requestIds.size());
     if (timeout)
     {
         auto const start_time = std::chrono::high_resolution_clock::now();
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
             auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                 std::chrono::high_resolution_clock::now() - start_time);
-            v[i] = awaitResponses(requestIds[i],
-                timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0});
+            responses.emplace_back(awaitResponses(
+                requestId, timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
         }
     }
     else
     {
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
-            v[i] = awaitResponses(requestIds[i]);
+            responses.emplace_back(awaitResponses(requestId));
         }
     }
-    return v;
+    return responses;
 }

 SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
@@ -1663,7 +1668,7 @@ void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::s
         }
         // Remove from the requestList
-        activeRequests.erase(it++);
+        it = activeRequests.erase(it);
     }
 }


@@ -107,7 +107,7 @@ public:
     std::vector<Response> awaitResponses(std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);
     std::vector<Response> awaitResponses(
-        IdType const& optId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
+        IdType const& reqId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
     std::vector<std::vector<Response>> awaitResponses(
         std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout);


@@ -30,7 +30,7 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
 - [trtllm-serve](#trtllm-serve)
 - [Disaggregated Serving](#disaggregated-serving)
 - [Dynamo](#dynamo)
-- [tensorrtllm_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
+- [tensorrtllm\_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
 - [Advanced Usages](#advanced-usages)
 - [Multi-node](#multi-node)
 - [mpirun](#mpirun)
@@ -40,6 +40,8 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
 - [FlashMLA](#flashmla)
 - [FP8 KV Cache and MLA](#fp8-kv-cache-and-mla)
 - [W4AFP8](#w4afp8)
+- [Activation calibration](#activation-calibration)
+- [Weight quantization and assembling](#weight-quantization-and-assembling)
 - [KV Cache Reuse](#kv-cache-reuse)
 - [Notes and Troubleshooting](#notes-and-troubleshooting)
 - [Known Issues](#known-issues)
@@ -227,6 +229,8 @@ trtllm-eval --model <YOUR_MODEL_DIR> \
 ## Serving
 ### trtllm-serve
+Take max-throughput scenario on B200 as an example, the settings are extracted from the [blog](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md#b200-max-throughput). **For users' own models and cases, the specific settings could be different to get best performance.**
 To serve the model using `trtllm-serve`:
 ```bash
@@ -253,12 +257,12 @@ trtllm-serve \
     --host localhost \
     --port 8000 \
     --backend pytorch \
-    --max_batch_size 161 \
-    --max_num_tokens 1160 \
+    --max_batch_size 384 \
+    --max_num_tokens 1536 \
     --tp_size 8 \
     --ep_size 8 \
     --pp_size 1 \
-    --kv_cache_free_gpu_memory_fraction 0.95 \
+    --kv_cache_free_gpu_memory_fraction 0.85 \
     --extra_llm_api_options ./extra-llm-api-config.yml
 ```


@@ -219,7 +219,6 @@ class PPInitCaller(type):
     def __call__(cls, *args, **kwargs):
         obj = type.__call__(cls, *args, **kwargs)
         obj.__pp_init__()
         return obj
@@ -235,6 +234,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
         self.model_config = model_config
         self.prologue = []
         self.epilogue = []
+        self.keep_embed_tokens = False

     def forward(
         self,
@@ -278,7 +278,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
             )
             return

-        if hasattr(self, "embed_tokens"):
+        if hasattr(self, "embed_tokens") and not self.keep_embed_tokens:
             self.prologue.append(self.embed_tokens)
         if hasattr(self, "norm"):
             self.epilogue.append(self.norm)
@@ -394,6 +394,8 @@ class DecoderModelForCausalLM(nn.Module,
             assert self.lm_head.tp_mode == self.model.embed_tokens.tp_mode, (
                 "lm_head and vocab embedding should use the same TP mode")
             self.lm_head.weight = self.model.embed_tokens.weight
+            if config.mapping.is_last_pp_rank():
+                self.model.keep_embed_tokens = True

         self.logits_processor = LogitsProcessor()
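
The `keep_embed_tokens` flag above preserves `embed_tokens` on the last pipeline-parallel rank when `lm_head` ties its weight to the embedding, so trimming the prologue does not discard a tensor the head still shares. A minimal weight-tying sketch in plain PyTorch (the `TiedVocabHead` class and sizes are illustrative, not TensorRT-LLM code):

```python
# Illustrative weight-tying sketch (hypothetical class, not TensorRT-LLM code):
# the output projection reuses the embedding matrix, so the embedding module
# must stay alive on whichever rank owns the head.
import torch
import torch.nn as nn


class TiedVocabHead(nn.Module):
    def __init__(self, vocab_size: int = 1000, hidden_size: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the projection to the embedding; both now share one Parameter.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embed_tokens(token_ids)
        return self.lm_head(hidden)


if __name__ == "__main__":
    model = TiedVocabHead()
    logits = model(torch.randint(0, 1000, (2, 8)))
    print(logits.shape)  # torch.Size([2, 8, 1000])
    # Dropping embed_tokens here would also invalidate lm_head.weight.
```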


@@ -143,8 +143,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     if not enable_chunked_prefill and max_num_tokens < dataset_metadata.max_isl:
         logger.warning(
             f"Chunked prefill is disabled, but max_num_tokens ({max_num_tokens}) is less than the max ISL ({dataset_metadata.max_isl}). "
-            f"Forcing max_num_tokens to {dataset_metadata.max_isl}.")
-        max_num_tokens = dataset_metadata.max_isl
+            f"Forcing max_num_tokens to {dataset_metadata.max_isl + max_batch_size}."
+        )
+        max_num_tokens = dataset_metadata.max_isl + max_batch_size

     pyt_options = {
         "use_cuda_graph":


@@ -618,6 +618,11 @@ def compute_logprobs(
     # reshape from [1, T, V] to [T, V]
     logits = logits.squeeze(0)
+    if tokens is not None and logits.size(0) > len(tokens):
+        # WAR for nvbug 5324291 where TRT backend might return more logits
+        # than output tokens.
+        logits = logits[:len(tokens)]
     logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
     topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
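
The added guard trims trailing logits before the softmax so each log-probability row lines up with a returned token. A self-contained sketch of the same truncate-then-top-k step (kept on CPU so it runs anywhere; the helper name is illustrative):

```python
# Sketch of the truncate-then-log-softmax step shown above (illustrative only).
import torch
import torch.nn.functional as F


def topk_logprobs(logits: torch.Tensor, tokens, top_k: int = 5):
    # logits: [1, T, V] or [T, V]; keep only one row per output token.
    if logits.dim() == 3:
        logits = logits.squeeze(0)
    if tokens is not None and logits.size(0) > len(tokens):
        logits = logits[:len(tokens)]
    logprobs = F.log_softmax(logits.float(), dim=-1)
    return torch.topk(logprobs, k=top_k, dim=-1)


if __name__ == "__main__":
    vals, idx = topk_logprobs(torch.randn(1, 6, 32), tokens=[1, 2, 3, 4])
    print(vals.shape, idx.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```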


@@ -935,13 +935,17 @@ class ModelRunnerCpp(ModelRunnerMixin):
         output_ids = [[[] for _ in range(num_sequences)]
                       for _ in range(len(request_ids))]

-        multi_responses = self.session.await_responses(request_ids)
-        responses = [
-            response for responses in multi_responses for response in responses
-        ]
+        all_responses = []
+        finished_request_ids = set()
+        while finished_request_ids != set(request_ids):
+            responses = self.session.await_responses()
+            for response in responses:
+                if response.result.is_final:
+                    finished_request_ids.add(response.request_id)
+            all_responses.extend(responses)

         return self._fill_output(
-            responses=responses,
+            responses=all_responses,
             output_ids=output_ids,
             end_id=end_id,
             return_dict=return_dict,
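
The runner now drains `await_responses()` until every request has reported a final response instead of issuing one blocking call for the whole batch. A minimal mock of that drain loop (the `FakeSession`/`FakeResponse` classes are stand-ins, not the TensorRT-LLM bindings):

```python
# Stand-in objects that mimic the drain loop above; not the real bindings.
from dataclasses import dataclass


@dataclass
class FakeResult:
    is_final: bool


@dataclass
class FakeResponse:
    request_id: int
    result: FakeResult


class FakeSession:
    """Yields one non-final and one final response per request id."""

    def __init__(self, request_ids):
        self._pending = [FakeResponse(rid, FakeResult(False)) for rid in request_ids]
        self._pending += [FakeResponse(rid, FakeResult(True)) for rid in request_ids]

    def await_responses(self):
        # Hand back a small batch each call, as a streaming executor would.
        batch, self._pending = self._pending[:2], self._pending[2:]
        return batch


def drain(session, request_ids):
    all_responses, finished = [], set()
    while finished != set(request_ids):
        responses = session.await_responses()
        for response in responses:
            if response.result.is_final:
                finished.add(response.request_id)
        all_responses.extend(responses)
    return all_responses


if __name__ == "__main__":
    out = drain(FakeSession([0, 1, 2]), [0, 1, 2])
    print(len(out))  # 6 responses: two per request
```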


@@ -93,6 +93,7 @@ def test_llm_hf_gemma_quantization_1gpu_vswa(batch_size, data_type,
                                              gemma_example_root,
                                              llm_datasets_root, llm_rouge_root,
                                              qformat):
+    skip_fp8_pre_ada(use_fp8=qformat == "fp8")
     max_attention_window = VSWA_ATTENTION[Path(gemma_model_root).stem]
     hf_gemma_quantization_1gpu(batch_size, data_type, gemma_model_root,
                                llm_venv, cmodel_dir, engine_dir,


@@ -1977,15 +1977,15 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
             ],
             "video": [
                 ["city", "night", "lights", "jacket", "wet"],
-                ["earth", "spinning", "black", "illuminated", "lights"],
+                ["earth", "spinning", "black"],
             ],
         },
         "qwen2.5-vl-7b-instruct": {
             "image": [
                 ["dramatic", "moody", "stormy", "turbulent", "wave"],
                 [
-                    "dome", "yosemite", "landmark", "sunny", "rock", "clouds",
-                    "pleasant"
+                    "large", "dome", "yosemite", "landmark", "rock", "road",
+                    "formation"
                 ],
                 ["highway", "traffic", "vehicles", "bus", "police"],
             ],


@@ -574,15 +574,16 @@ def setup_cache_data(request, tensorrt_llm_example_root):

     def cleanup_engine_outputs(output_dir_root):
-        for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
-            for dirname in dirnames:
-                if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
-                    folder_path = os.path.join(dirpath, dirname)
-                    try:
-                        shutil.rmtree(folder_path)
-                        print_info(f"Deleted folder: {folder_path}")
-                    except Exception as e:
-                        print_info(f"Error deleting {folder_path}: {e}")
+        if output_dir_root is not None:
+            for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
+                for dirname in dirnames:
+                    if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
+                        folder_path = os.path.join(dirpath, dirname)
+                        try:
+                            shutil.rmtree(folder_path)
+                            print_info(f"Deleted folder: {folder_path}")
+                        except Exception as e:
+                            print_info(f"Error deleting {folder_path}: {e}")

     # Teardown hook to clean up engine outputs after each group of test cases are finished


@@ -70,6 +70,7 @@ def test_llama_v3_8b_rss_increasement(
     inflight_batcher_llm_client_root,
    tensorrt_llm_llama_example_root,
     llama_v3_8b_model_root,
+    tensorrt_llm_example_root,
     llm_backend_venv,
 ):
     if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization":
@@ -83,7 +84,8 @@ def test_llama_v3_8b_rss_increasement(
     llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"]

     # Build engine
-    ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_llama_example_root,
+    ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_example_root,
+                                             tensorrt_llm_llama_example_root,
                                              llama_v3_8b_model_root,
                                              workers=1)