From bf9ac96de3801a09695a41226f07683c32605912 Mon Sep 17 00:00:00 2001
From: rakib-hasan
Date: Tue, 6 May 2025 21:15:41 -0700
Subject: [PATCH] Adding option to specify a set of token ids for multimodal tokens (#4107)
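
This lets models whose multimodal placeholder tokens are ordinary
in-vocabulary ids mark them explicitly, instead of relying on the
out-of-vocabulary (>= vocab_size) convention. Illustrative sketch of
the selection logic the new option enables (the token ids below are
made up, not taken from any particular tokenizer):

    import torch

    # Hypothetical prompt: two image-placeholder tokens (id 9999) mixed
    # with ordinary text tokens, all inside the model vocabulary.
    input_ids = torch.tensor([5, 9999, 9999, 17, 42])

    # Default path (mm_token_ids=None): multimodal tokens must satisfy
    # input_ids >= vocab_size, so in-vocab placeholders are not detected.
    # New path: pass the placeholder ids explicitly.
    mm_token_ids = torch.tensor([9999])
    mm_token_mask = torch.isin(input_ids, mm_token_ids)
    # mm_token_mask -> tensor([False,  True,  True, False, False])

fuse_input_embeds then uses this mask to split text positions from
multimodal positions before the embedding lookup and fusion.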

Signed-off-by: Rakib Hasan
---
 .../models/modeling_multimodal_utils.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
index 620ca65f5c..e97dcc6e07 100644
--- a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
@@ -32,6 +32,7 @@ def fuse_input_embeds(
     embedding_layer: Embedding,
     input_ids: torch.LongTensor,
     mm_embeds: List[torch.Tensor],
+    mm_token_ids: Optional[torch.LongTensor] = None,
 ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
     """
     Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries.
@@ -39,6 +40,7 @@
     Args:
         input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For LLM model, the requests are inflight batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text / multimodal tokens and naturally batched the LLM embedding lookup
         mm_embed: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
+        mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens, i.e. `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
     Returns:
         - If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests, either context or generation phase (3) multimodal run, all requests in generation phase --> there is no multimodal data, return only the input_ids
         - If (4) multimodal run, mixed batch of context and generation requests, each context request has a multimodal feature --> return only the fused input_embeds of shape [total length, hidden_dim]. For text tokens, LLM embedding layer has already run.
@@ -46,11 +48,26 @@
     if len(mm_embeds) == 0:
         return input_ids, None
 
-    vocab_size = embedding_layer.num_embeddings
     mm_embed = torch.cat(mm_embeds, dim=0)
 
-    text_token_indices = torch.where(input_ids < vocab_size)[0]
-    mm_token_indices = torch.where(input_ids >= vocab_size)[0]
+    if mm_token_ids is None:
+        # NOTE:
+        # If mm_token_ids is None, it is assumed that the multimodal
+        # tokens are out-of-vocab tokens, i.e. `input_ids` contains
+        # tokens >= vocab_size that represent the multimodal tokens.
+        # Since mm_token_ids would be unbounded in this case,
+        # using torch.isin() may not be performant.
+        # The vocab_size comparison is a more performant alternative
+        # that keeps the flexibility of still specifying all possible
+        # mm_token_ids, if the user wants to.
+        vocab_size = embedding_layer.num_embeddings
+        mm_token_mask = input_ids >= vocab_size
+        text_token_mask = input_ids < vocab_size
+    else:
+        mm_token_mask = torch.isin(input_ids, mm_token_ids)
+        text_token_mask = ~mm_token_mask
+    text_token_indices = torch.where(text_token_mask)[0]
+    mm_token_indices = torch.where(mm_token_mask)[0]
 
     text_embed = embedding_layer(input_ids[text_token_indices])
     input_embeds = torch.empty(input_ids.shape[0],