Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
* Why? Certain VLMs, like the Qwen family, need more than just the multimodal embeddings in the language model; they also need MRoPE position IDs and deltas. Prior to this commit, only the embeddings could be communicated from the encoder worker to the prefill worker.
* What? This commit extends `DisaggregatedParams` to include the MRoPE information. It also adjusts several pieces of code required to communicate that information between the E, P, and D workers.

Closes TRTLLM-9409.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
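As an illustration of the new fields, here is a minimal, hypothetical sketch of an encoder (E) worker populating `DisaggregatedParams` for the prefill (P) worker. The module path, the layout of the handle dicts, and all placeholder values are assumptions for illustration; creating the CUDA IPC handles themselves happens outside this file.

# Hypothetical sketch: encoder-side population of the extended params.
# Module path and handle-dict contents are assumptions, not the library's API.
from tensorrt_llm.disaggregated_params import DisaggregatedParams

# Placeholder handle dicts standing in for whatever CUDA IPC export the
# encoder worker actually performs (not part of this file).
embedding_handles = [{"ipc_handle": b"...", "shape": [256, 4096]}]  # one per multimodal item
mm_hashes = [[0, 1, 2, 3, 4, 5, 6, 7]]                              # one list of 8 ints per item
ids_handle = {"ipc_handle": b"...", "shape": [3, 1, 1024]}          # MRoPE position ids
deltas_handle = {"ipc_handle": b"...", "shape": [1, 1]}             # MRoPE position deltas

params = DisaggregatedParams(
    request_type="context_only",
    multimodal_embedding_handles=embedding_handles,
    multimodal_hashes=mm_hashes,
    mrope_position_ids_handle=ids_handle,        # new in this commit
    mrope_position_deltas_handle=deltas_handle,  # new in this commit
)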
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np

# isort: off
# needed before trying to import bindings to load tensorrt_libs
import tensorrt as trt  # noqa
# isort: on

from tensorrt_llm.bindings import executor as tllme


@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
    """Disaggregated serving parameters.

    Args:
        request_type (str): The type of request ("context_only" | "generation_only" | "context_and_generation")
        first_gen_tokens (List[int]): The first tokens of the generation request
        ctx_request_id (int): The context request id
        opaque_state (bytes): Any additional state needing to be exchanged between context and gen instances
        draft_tokens (List[int]): The draft tokens of the generation request

        multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from the ViT encoder.
        multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
        mrope_position_ids_handle (Dict[str, Any]): Handle to the MRoPE position ids produced by the encoder worker.
        mrope_position_deltas_handle (Dict[str, Any]): Handle to the MRoPE position deltas produced by the encoder worker.
    """

    request_type: Optional[str] = None
    # P-D Disaggregated Params
    first_gen_tokens: Optional[List[int]] = None
    ctx_request_id: Optional[int] = None
    opaque_state: Optional[bytes] = None
    draft_tokens: Optional[List[int]] = None

    # E-P Disaggregated Params
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
        None  # a list of CUDA IPC handles, one per multimodal embedding
    )
    multimodal_hashes: Optional[List[List[int]]] = (
        None  # user-provided multimodal hashes, one list of 8 integers per multimodal item
    )
    mrope_position_ids_handle: Optional[Dict[str, Any]] = None
    mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

    def get_context_phase_params(self) -> tllme.ContextPhaseParams:
        return tllme.ContextPhaseParams(
            self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
        )

    def get_request_type(self) -> tllme.RequestType:
        if self.request_type == "context_only":
            return tllme.RequestType.REQUEST_TYPE_CONTEXT_ONLY
        elif self.request_type == "generation_only":
            return tllme.RequestType.REQUEST_TYPE_GENERATION_ONLY
        elif self.request_type == "context_and_generation":
            return tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        else:
            raise ValueError(
                f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
                "context_and_generation"
            )

    def __post_init__(self):
        if self.request_type is not None:
            self.request_type = self.request_type.lower()
            if self.request_type not in [
                "context_only",
                "generation_only",
                "context_and_generation",
            ]:
                raise ValueError(
                    f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
                    "context_and_generation"
                )
        if self.multimodal_embedding_handles is not None:
            if self.multimodal_hashes is not None:
                # if mm hashes are provided, KV cache reuse can be enabled
                assert len(self.multimodal_embedding_handles) == len(self.multimodal_hashes), (
                    "multimodal_embedding_handles and multimodal_hashes must have the same length"
                )
                for mm_hash in self.multimodal_hashes:
                    assert isinstance(mm_hash, list), "mm_hash must be a list"
                    assert len(mm_hash) == 8, "mm_hash must be a list of 8 integers"
                    assert all(isinstance(x, int) for x in mm_hash), "mm_hash must contain integers"
            else:
                # if the user did not provide mm hashes, KV cache reuse will be disabled;
                # fall back to a random placeholder hash shared by every embedding handle
                assert len(self.multimodal_embedding_handles) > 0, (
                    "multimodal_embedding_handles must not be empty"
                )
                vals = np.random.randint(
                    np.iinfo(np.int32).min, np.iinfo(np.int32).max, size=8, dtype=np.int32
                ).tolist()
                self.multimodal_hashes = [vals] * len(self.multimodal_embedding_handles)
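For reference, a minimal sketch of the hash rules enforced by `__post_init__` above, assuming a working TensorRT-LLM installation so the module and its bindings can be imported; the module path and the handle dicts are placeholders, not the library's actual handle format.

from tensorrt_llm.disaggregated_params import DisaggregatedParams  # module path assumed

handles = [{"ipc_handle": b"..."}, {"ipc_handle": b"..."}]  # placeholder handle dicts

# Hashes provided: one list of 8 ints per handle, so KV cache reuse can be enabled.
with_reuse = DisaggregatedParams(
    request_type="context_only",
    multimodal_embedding_handles=handles,
    multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]],
)

# Hashes omitted: __post_init__ fills in a random placeholder hash (the same one
# for every handle), so KV cache reuse is effectively disabled.
without_reuse = DisaggregatedParams(
    request_type="context_only",
    multimodal_embedding_handles=handles,
)
assert len(without_reuse.multimodal_hashes) == len(handles)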