mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Adding option to specify a set of token ids for multimodal tokens (#4107)
Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
commit bf9ac96de3
parent f670a036df
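In short, callers can now pass the exact set of token ids that mark multimodal positions, instead of relying on the out-of-vocabulary convention. A hedged call-site sketch (the id value 32010 and the surrounding variable names are illustrative, not taken from this commit):

    # Hypothetical VLM whose processor marks image positions with a
    # dedicated in-vocabulary placeholder id (32010 is invented here).
    mm_token_ids = torch.tensor([32010], dtype=torch.long)
    input_ids, input_embeds = fuse_input_embeds(
        embedding_layer, input_ids, mm_embeds, mm_token_ids=mm_token_ids)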
@@ -32,6 +32,7 @@ def fuse_input_embeds(
     embedding_layer: Embedding,
     input_ids: torch.LongTensor,
     mm_embeds: List[torch.Tensor],
+    mm_token_ids: Optional[torch.LongTensor] = None,
 ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
     """
     Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries.
@@ -39,6 +40,7 @@ def fuse_input_embeds(
     Args:
         input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For the LLM model, the requests are inflight-batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text/multimodal tokens and naturally batch the LLM embedding lookup.
         mm_embed: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
+        mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens, i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
     Returns:
         - If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests in either context or generation phase, or (3) multimodal run with all requests in generation phase --> there is no multimodal data, so return only the input_ids.
         - If (4) multimodal run, i.e. a mixed batch of context and generation requests where each context request has a multimodal feature --> return only the fused input_embeds of shape [total_length, hidden_dim]. For text tokens, the LLM embedding layer has already run.
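To make the flattened layout concrete, a toy example (values invented for illustration): with vocab_size = 100 and one context request carrying a 3-token image feature, the masks separate text and multimodal positions cleanly:

    # Toy layout: text ids < 100, image placeholders are OOV ids >= 100.
    input_ids = torch.tensor([5, 17, 100, 101, 102, 9])
    text_token_mask = input_ids < 100   # [True, True, False, False, False, True]
    mm_token_mask = input_ids >= 100    # the 3 rows of mm_embed land here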
@@ -46,11 +48,26 @@ def fuse_input_embeds(
     if len(mm_embeds) == 0:
         return input_ids, None
 
-    vocab_size = embedding_layer.num_embeddings
     mm_embed = torch.cat(mm_embeds, dim=0)
 
-    text_token_indices = torch.where(input_ids < vocab_size)[0]
-    mm_token_indices = torch.where(input_ids >= vocab_size)[0]
+    if mm_token_ids is None:
+        # NOTE:
+        # If mm_token_ids is None, it is assumed that the multimodal
+        # tokens are out-of-vocab tokens, i.e. the `input_ids` contains
+        # tokens >= vocab_size that represent the multimodal tokens.
+        # Since mm_token_ids would be unbounded in this case,
+        # using torch.isin() may not be performant.
+        # This provides a more performant alternative while keeping
+        # the flexibility of still specifying all possible mm_token_ids,
+        # if the user wants to.
+        vocab_size = embedding_layer.num_embeddings
+        mm_token_mask = input_ids >= vocab_size
+        text_token_mask = input_ids < vocab_size
+    else:
+        mm_token_mask = torch.isin(input_ids, mm_token_ids)
+        text_token_mask = ~mm_token_mask
+    text_token_indices = torch.where(text_token_mask)[0]
+    mm_token_indices = torch.where(mm_token_mask)[0]
 
     text_embed = embedding_layer(input_ids[text_token_indices])
     input_embeds = torch.empty(input_ids.shape[0],
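For experimenting with the selection logic outside TensorRT-LLM, here is a minimal self-contained sketch of the same fusion idea, using a plain torch.nn.Embedding in place of the repo's Embedding layer; toy_fuse and all values are mine, not the library's:

    import torch
    import torch.nn as nn

    def toy_fuse(embedding, input_ids, mm_embeds, mm_token_ids=None):
        # Mirrors the two selection paths above: explicit id set vs. the
        # out-of-vocabulary convention when mm_token_ids is None.
        if len(mm_embeds) == 0:
            return input_ids, None
        mm_embed = torch.cat(mm_embeds, dim=0)
        if mm_token_ids is None:
            mm_token_mask = input_ids >= embedding.num_embeddings
        else:
            mm_token_mask = torch.isin(input_ids, mm_token_ids)
        text_token_mask = ~mm_token_mask
        # Slice-and-assign: text rows come from the embedding table,
        # multimodal rows from the precomputed features.
        input_embeds = torch.empty(input_ids.shape[0], embedding.embedding_dim,
                                   dtype=mm_embed.dtype)
        input_embeds[text_token_mask] = embedding(
            input_ids[text_token_mask]).to(mm_embed.dtype)
        input_embeds[mm_token_mask] = mm_embed
        return None, input_embeds

    emb = nn.Embedding(8, 4)                    # vocab_size = 8, hidden_dim = 4
    ids = torch.tensor([1, 5, 8, 9, 2])         # 8 and 9 are OOV placeholders
    _, fused = toy_fuse(emb, ids, [torch.randn(2, 4)])
    print(fused.shape)                          # torch.Size([5, 4])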