Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
chore: Mass integration of release/0.20 (#5082)
Signed-off-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Co-authored-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: Erin <14718778+hchings@users.noreply.github.com>
Co-authored-by: Frank <3429989+FrankD412@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
parent 13eef642e6
commit 8451a87742
@@ -947,30 +947,32 @@ std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses]() -> bool { return !resp.empty() || mShutdown; };
-    auto storeResponses = [this, &resp = this->mResponses, &responses]()
+    auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
+    auto storeResponses = [this]()
     {
-        for (auto it = resp.cbegin(); it != resp.cend();)
+        std::vector<Response> responses;
+        for (auto it = mResponses.begin(); it != mResponses.end();)
         {
             responses.insert(responses.end(), it->second.begin(), it->second.end());
             addTerminatedReqId(it->second, it->first);
-            resp.erase(it++);
+            it = mResponses.erase(it);
         }
+        return responses;
     };
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeResponses();
+            responses = storeResponses();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeResponses();
+        responses = storeResponses();
     }
     return responses;
 }
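The hunk above changes `storeResponses` from filling a vector captured by reference to building and returning its own vector, and it drains `mResponses` with the iterator-returning `erase`. The same wait-with-optional-timeout-then-drain shape can be sketched outside the project; the snippet below is an illustrative Python analogue built on `threading.Condition`. None of it is TensorRT-LLM code, and the class and method names are made up for the example.

```python
import threading
from collections import defaultdict


class ResponseQueue:
    """Illustrative analogue of the executor's response store (not TensorRT-LLM code)."""

    def __init__(self):
        self._cv = threading.Condition()
        self._responses = defaultdict(list)  # request id -> pending responses
        self._shutdown = False

    def push(self, req_id, response):
        with self._cv:
            self._responses[req_id].append(response)
            self._cv.notify_all()

    def await_responses(self, timeout=None):
        """Block until any response is available (or shutdown), then drain them all."""
        with self._cv:
            # wait_for returns False if the timeout expires with the predicate still
            # false, mirroring `mResponsesCv.wait_for(lck, timeout, pred)` above.
            ready = self._cv.wait_for(lambda: self._responses or self._shutdown, timeout)
            if not ready:
                return []
            drained = []
            for req_id in list(self._responses):
                drained.extend(self._responses.pop(req_id))
            return drained

    def shutdown(self):
        with self._cv:
            self._shutdown = True
            self._cv.notify_all()


if __name__ == "__main__":
    q = ResponseQueue()
    q.push(1, "partial token")
    q.push(2, "final token")
    print(q.await_responses(timeout=0.1))  # -> ['partial token', 'final token']
    print(q.await_responses(timeout=0.1))  # -> [] (nothing arrives within the timeout)
```

`Condition.wait_for` returns the final value of the predicate, which is why a timed-out wait can simply yield an empty list, matching the `if (mResponsesCv.wait_for(...))` branch above.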
@@ -980,15 +982,16 @@ std::vector<Response> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses, reqId]() -> bool
-    { return (resp.find(reqId) != resp.end() && !resp.at(reqId).empty()) || mShutdown; };
-    auto storeIdResponse = [this, &resp = this->mResponses, &responses, reqId]()
+    auto pred = [this, reqId]() -> bool
+    { return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
+    auto storeIdResponse = [this, reqId]()
     {
-        responses.swap(resp.at(reqId));
-        resp.erase(reqId);
+        std::vector<Response> responses;
+        responses.swap(mResponses.at(reqId));
+        mResponses.erase(reqId);
         addTerminatedReqId(responses, reqId);
+        return responses;
     };
 
     // We don't process a terminated request again. Terminated request is defined as a response
@@ -1005,17 +1008,18 @@ std::vector<Response> Executor::Impl::awaitResponses(
         return {Response(reqId, err)};
     }
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeIdResponse();
+            responses = storeIdResponse();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeIdResponse();
+        responses = storeIdResponse();
     }
     return responses;
 }
@@ -1025,26 +1029,27 @@ std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<std::vector<Response>> v(requestIds.size());
+    std::vector<std::vector<Response>> responses;
+    responses.reserve(requestIds.size());
     if (timeout)
     {
         auto const start_time = std::chrono::high_resolution_clock::now();
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
             auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                 std::chrono::high_resolution_clock::now() - start_time);
-            v[i] = awaitResponses(requestIds[i],
-                timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0});
+            responses.emplace_back(awaitResponses(
+                requestId, timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
         }
     }
     else
     {
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
-            v[i] = awaitResponses(requestIds[i]);
+            responses.emplace_back(awaitResponses(requestId));
         }
     }
-    return v;
+    return responses;
 }
 
 SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
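The multi-request overload above spreads a single overall timeout across sequential per-request waits: each iteration measures how much time has already been spent and passes only the remainder, clamped at zero. Below is a minimal standalone sketch of that budgeting pattern in Python, with a stand-in `wait_one` callback (illustrative only, not project code).

```python
import time


def await_all(wait_one, request_ids, timeout_s):
    """Call `wait_one(req_id, remaining_timeout)` for each id, sharing one overall budget.

    `wait_one` is a stand-in for a per-request blocking wait such as the
    per-id `awaitResponses(requestId, remaining)` call in the diff above.
    """
    results = []
    start = time.monotonic()
    for req_id in request_ids:
        elapsed = time.monotonic() - start
        remaining = max(timeout_s - elapsed, 0.0)  # clamp: never pass a negative timeout
        results.append(wait_one(req_id, remaining))
    return results


if __name__ == "__main__":
    def fake_wait(req_id, remaining):
        # Pretend each request takes up to 50 ms to produce a response.
        time.sleep(min(0.05, remaining))
        return f"request {req_id} waited up to {remaining:.3f}s"

    for line in await_all(fake_wait, [11, 12, 13], timeout_s=0.12):
        print(line)
```

Here `time.monotonic()` plays the role of the `start_time`/`elapsed_ms` bookkeeping in the C++ loop.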
@@ -1663,7 +1668,7 @@ void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::s
         }
 
         // Remove from the requestList
-        activeRequests.erase(it++);
+        it = activeRequests.erase(it);
     }
 }
 
@@ -107,7 +107,7 @@ public:
     std::vector<Response> awaitResponses(std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);
 
     std::vector<Response> awaitResponses(
-        IdType const& optId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
+        IdType const& reqId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
 
     std::vector<std::vector<Response>> awaitResponses(
         std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout);
 
@@ -30,7 +30,7 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
   - [trtllm-serve](#trtllm-serve)
   - [Disaggregated Serving](#disaggregated-serving)
     - [Dynamo](#dynamo)
-  - [tensorrtllm_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
+  - [tensorrtllm\_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
 - [Advanced Usages](#advanced-usages)
   - [Multi-node](#multi-node)
     - [mpirun](#mpirun)
@@ -40,6 +40,8 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
   - [FlashMLA](#flashmla)
   - [FP8 KV Cache and MLA](#fp8-kv-cache-and-mla)
   - [W4AFP8](#w4afp8)
+    - [Activation calibration](#activation-calibration)
+    - [Weight quantization and assembling](#weight-quantization-and-assembling)
   - [KV Cache Reuse](#kv-cache-reuse)
 - [Notes and Troubleshooting](#notes-and-troubleshooting)
 - [Known Issues](#known-issues)
@@ -227,6 +229,8 @@ trtllm-eval --model <YOUR_MODEL_DIR> \
 ## Serving
 ### trtllm-serve
 
+Take max-throughput scenario on B200 as an example, the settings are extracted from the [blog](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md#b200-max-throughput). **For users' own models and cases, the specific settings could be different to get best performance.**
+
 To serve the model using `trtllm-serve`:
 
 ```bash
@@ -253,12 +257,12 @@ trtllm-serve \
     --host localhost \
     --port 8000 \
     --backend pytorch \
-    --max_batch_size 161 \
-    --max_num_tokens 1160 \
+    --max_batch_size 384 \
+    --max_num_tokens 1536 \
     --tp_size 8 \
     --ep_size 8 \
     --pp_size 1 \
-    --kv_cache_free_gpu_memory_fraction 0.95 \
+    --kv_cache_free_gpu_memory_fraction 0.85 \
     --extra_llm_api_options ./extra-llm-api-config.yml
 ```
 
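Once `trtllm-serve` is running with the flags above, it exposes an OpenAI-compatible HTTP API on the configured host and port. A hedged example of querying it from Python follows; the model name and prompt are placeholders, and the accepted fields can vary by version, so treat this as a sketch rather than a reference client.

```python
import requests

# Assumes the trtllm-serve instance launched above is listening on localhost:8000.
BASE_URL = "http://localhost:8000/v1"

payload = {
    "model": "<YOUR_MODEL_DIR>",  # placeholder: use the model name the server reports under /v1/models
    "prompt": "Explain KV cache reuse in one sentence.",
    "max_tokens": 64,
    "temperature": 0.0,
}

resp = requests.post(f"{BASE_URL}/completions", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```

The same request can be issued with `curl` or an OpenAI-style client pointed at `http://localhost:8000/v1`.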
@@ -219,7 +219,6 @@ class PPInitCaller(type):
 
     def __call__(cls, *args, **kwargs):
         obj = type.__call__(cls, *args, **kwargs)
-        obj.__pp_init__()
         return obj
 
 
@@ -235,6 +234,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
         self.model_config = model_config
         self.prologue = []
         self.epilogue = []
+        self.keep_embed_tokens = False
 
     def forward(
         self,
@@ -278,7 +278,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
             )
             return
 
-        if hasattr(self, "embed_tokens"):
+        if hasattr(self, "embed_tokens") and not self.keep_embed_tokens:
             self.prologue.append(self.embed_tokens)
         if hasattr(self, "norm"):
             self.epilogue.append(self.norm)
@@ -394,6 +394,8 @@ class DecoderModelForCausalLM(nn.Module,
             assert self.lm_head.tp_mode == self.model.embed_tokens.tp_mode, (
                 "lm_head and vocab embedding should use the same TP mode")
             self.lm_head.weight = self.model.embed_tokens.weight
+            if config.mapping.is_last_pp_rank():
+                self.model.keep_embed_tokens = True
 
         self.logits_processor = LogitsProcessor()
 
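The hunks above handle tied word embeddings under pipeline parallelism: when `lm_head.weight` aliases `embed_tokens.weight`, the last pipeline rank still needs the embedding around, so the new `keep_embed_tokens` flag stops `__pp_init__` from scheduling it for removal there. A toy sketch of the underlying weight tying in plain PyTorch (illustrative only, not the TensorRT-LLM classes):

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Toy model whose lm_head shares storage with the token embedding."""

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the weights: both modules now point at the same Parameter, which is
        # why the embedding must stay on whichever rank owns the lm_head.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, token_ids):
        hidden_states = self.embed_tokens(token_ids)
        return self.lm_head(hidden_states)


model = TinyTiedLM()
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
logits = model(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3, 100])
```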
@@ -143,8 +143,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     if not enable_chunked_prefill and max_num_tokens < dataset_metadata.max_isl:
         logger.warning(
             f"Chunked prefill is disabled, but max_num_tokens ({max_num_tokens}) is less than the max ISL ({dataset_metadata.max_isl}). "
-            f"Forcing max_num_tokens to {dataset_metadata.max_isl}.")
-        max_num_tokens = dataset_metadata.max_isl
+            f"Forcing max_num_tokens to {dataset_metadata.max_isl + max_batch_size}."
+        )
+        max_num_tokens = dataset_metadata.max_isl + max_batch_size
 
     pyt_options = {
         "use_cuda_graph":
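The benchmark hunk above raises the forced `max_num_tokens` from `max_isl` to `max_isl + max_batch_size` when chunked prefill is disabled; the extra headroom presumably leaves room for generation tokens of other requests scheduled alongside the longest full prompt in the same step. A tiny sketch of the guard with hypothetical numbers:

```python
def adjusted_max_num_tokens(max_num_tokens, max_isl, max_batch_size, enable_chunked_prefill):
    """Mirror of the guard above (illustrative): without chunked prefill, one scheduling
    step must hold the longest full prompt (max_isl), plus headroom of max_batch_size
    tokens for the rest of the batch."""
    if not enable_chunked_prefill and max_num_tokens < max_isl:
        return max_isl + max_batch_size
    return max_num_tokens


print(adjusted_max_num_tokens(1160, max_isl=2048, max_batch_size=384,
                              enable_chunked_prefill=False))  # -> 2432
```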
@@ -618,6 +618,11 @@ def compute_logprobs(
         # reshape from [1, T, V] to [T, V]
         logits = logits.squeeze(0)
 
+    if tokens is not None and logits.size(0) > len(tokens):
+        # WAR for nvbug 5324291 where TRT backend might return more logits
+        # than output tokens.
+        logits = logits[:len(tokens)]
+
     logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
     topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
 
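For context, the surrounding `compute_logprobs` turns raw logits into per-position top-k log-probabilities with `log_softmax` followed by `topk`, and the inserted workaround simply truncates extra logit rows first. A standalone sketch of those steps (CPU tensors and made-up shapes; the project code moves the logits to CUDA first):

```python
import torch
import torch.nn.functional as F

# Fake logits for a 4-position sequence over a 10-word vocabulary: shape [T, V].
logits = torch.randn(4, 10)

# Truncate to the number of generated tokens, as the workaround above does
# when the backend returns more logit rows than output tokens.
tokens = [3, 7, 1]
logits = logits[:len(tokens)]

logprobs = F.log_softmax(logits.float(), dim=-1)              # normalize each row
topk_vals, topk_indices = torch.topk(logprobs, k=2, dim=-1)   # best 2 candidates per position

for pos, (vals, idxs) in enumerate(zip(topk_vals, topk_indices)):
    print(f"position {pos}: token ids {idxs.tolist()} with logprobs {vals.tolist()}")
```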
@@ -935,13 +935,17 @@ class ModelRunnerCpp(ModelRunnerMixin):
             output_ids = [[[] for _ in range(num_sequences)]
                           for _ in range(len(request_ids))]
 
-            multi_responses = self.session.await_responses(request_ids)
-            responses = [
-                response for responses in multi_responses for response in responses
-            ]
+            all_responses = []
+            finished_request_ids = set()
+            while finished_request_ids != set(request_ids):
+                responses = self.session.await_responses()
+                for response in responses:
+                    if response.result.is_final:
+                        finished_request_ids.add(response.request_id)
+                all_responses.extend(responses)
 
             return self._fill_output(
-                responses=responses,
+                responses=all_responses,
                 output_ids=output_ids,
                 end_id=end_id,
                 return_dict=return_dict,
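The new loop above keeps calling `await_responses()` until every request id has produced a final response, rather than issuing one bulk wait. The toy sketch below reproduces that termination condition against a fake session object; `FakeSession`, `Response`, and `Result` are invented for the example and only mirror the attribute names used in the diff.

```python
from collections import namedtuple

Result = namedtuple("Result", ["is_final", "text"])
Response = namedtuple("Response", ["request_id", "result"])


class FakeSession:
    """Stand-in for the executor session; hands back a couple of responses per call."""

    def __init__(self):
        self._pending = [
            Response(1, Result(is_final=False, text="he")),
            Response(2, Result(is_final=True, text="world")),
            Response(1, Result(is_final=True, text="hello")),
        ]

    def await_responses(self):
        batch, self._pending = self._pending[:2], self._pending[2:]
        return batch


def collect_until_final(session, request_ids):
    all_responses = []
    finished = set()
    while finished != set(request_ids):
        responses = session.await_responses()
        for response in responses:
            if response.result.is_final:
                finished.add(response.request_id)
        all_responses.extend(responses)
    return all_responses


print(collect_until_final(FakeSession(), [1, 2]))
```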
@@ -93,6 +93,7 @@ def test_llm_hf_gemma_quantization_1gpu_vswa(batch_size, data_type,
                                              gemma_example_root,
                                              llm_datasets_root, llm_rouge_root,
                                              qformat):
+    skip_fp8_pre_ada(use_fp8=qformat == "fp8")
     max_attention_window = VSWA_ATTENTION[Path(gemma_model_root).stem]
     hf_gemma_quantization_1gpu(batch_size, data_type, gemma_model_root,
                                llm_venv, cmodel_dir, engine_dir,
@@ -1977,15 +1977,15 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
         ],
         "video": [
             ["city", "night", "lights", "jacket", "wet"],
-            ["earth", "spinning", "black", "illuminated", "lights"],
+            ["earth", "spinning", "black"],
         ],
     },
     "qwen2.5-vl-7b-instruct": {
         "image": [
             ["dramatic", "moody", "stormy", "turbulent", "wave"],
             [
-                "dome", "yosemite", "landmark", "sunny", "rock", "clouds",
-                "pleasant"
+                "large", "dome", "yosemite", "landmark", "rock", "road",
+                "formation"
             ],
             ["highway", "traffic", "vehicles", "bus", "police"],
         ],
@@ -574,15 +574,16 @@ def setup_cache_data(request, tensorrt_llm_example_root):
 
 
 def cleanup_engine_outputs(output_dir_root):
-    for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
-        for dirname in dirnames:
-            if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
-                folder_path = os.path.join(dirpath, dirname)
-                try:
-                    shutil.rmtree(folder_path)
-                    print_info(f"Deleted folder: {folder_path}")
-                except Exception as e:
-                    print_info(f"Error deleting {folder_path}: {e}")
+    if output_dir_root is not None:
+        for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
+            for dirname in dirnames:
+                if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
+                    folder_path = os.path.join(dirpath, dirname)
+                    try:
+                        shutil.rmtree(folder_path)
+                        print_info(f"Deleted folder: {folder_path}")
+                    except Exception as e:
+                        print_info(f"Error deleting {folder_path}: {e}")
 
 
 # Teardown hook to clean up engine outputs after each group of test cases are finished
@@ -70,6 +70,7 @@ def test_llama_v3_8b_rss_increasement(
     inflight_batcher_llm_client_root,
     tensorrt_llm_llama_example_root,
     llama_v3_8b_model_root,
+    tensorrt_llm_example_root,
     llm_backend_venv,
 ):
     if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization":
@@ -83,7 +84,8 @@ def test_llama_v3_8b_rss_increasement(
 
     llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"]
     # Build engine
-    ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_llama_example_root,
+    ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_example_root,
+                                             tensorrt_llm_llama_example_root,
                                              llama_v3_8b_model_root,
                                              workers=1)
 