mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
chore: Handle qwen2audio inputs ids expansion during processing (#3080)
* Handle qwen2audio inputs ids expansion during processing

  Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

* remove more dead code

  Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

* fix yapf

  Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

---------

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
parent
3c7cb6629c
commit
0ec7b5701f
@ -328,13 +328,9 @@ class QWenInfer(object):
|
||||
input_tokens,
|
||||
args,
|
||||
prompt_table=None,
|
||||
tasks=None,
|
||||
task_vocab_size=None,
|
||||
extra_ids=None,
|
||||
run_time=1,
|
||||
):
|
||||
input_ids = None
|
||||
input_lengths = None
|
||||
input_ids = torch.as_tensor(input_tokens,
|
||||
device=self.gpu_device,
|
||||
dtype=torch.int32)
|
||||
@ -398,8 +394,7 @@ class QWenInfer(object):
|
||||
stream,
|
||||
history=None,
|
||||
past_audio_features=None,
|
||||
run_time=1,
|
||||
gpu_id=0):
|
||||
run_time=1):
|
||||
assert input_text, "input_text must be provided"
|
||||
assert torch.cuda.is_available(), "no gpu available"
|
||||
# preprocessing on the CPU may be faster
|
||||
@ -464,9 +459,7 @@ class QWenInfer(object):
|
||||
# 1. Create a mask to know where special audio tokens are
|
||||
special_audio_token_mask = input_ids == self.config.audio_token_index
|
||||
special_audio_token_num = special_audio_token_mask.sum().item()
|
||||
if past_audio_features is None:
|
||||
assert special_audio_token_num == num_audios, f'special_audio_token_num {special_audio_token_num} should be equal to num_audios {num_audios}'
|
||||
else:
|
||||
if past_audio_features is not None:
|
||||
assert isinstance(past_audio_features,
|
||||
list), f'past_audio_features should be a list'
|
||||
assert (
|
||||
@ -497,40 +490,16 @@ class QWenInfer(object):
|
||||
batch_indices, non_audio_indices = torch.where(
|
||||
input_ids != self.config.audio_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged audio-text sequence.
|
||||
# `special_audio_token_mask` identifies audio tokens. Each audio token expands into as many positions as its entry in `num_audio_tokens`, i.e. it adds `num_audio_tokens - 1` extra positions.
|
||||
# `torch.cumsum` computes how each audio token shifts subsequent text token positions.
|
||||
token_placeholder_num = torch.zeros_like(input_ids, device=device)
|
||||
token_placeholder_num[
|
||||
special_audio_token_mask] = num_audio_tokens.long() - 1
|
||||
token_placeholder_num = token_placeholder_num + 1
|
||||
new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
|
||||
max_token_num = token_placeholder_num.sum(-1).max()
|
||||
text_to_overwrite = new_token_positions[batch_indices,
|
||||
non_audio_indices]
|
||||
# 2. Fill the final input ids based on the mask.
|
||||
batch_indices, audio_indices = torch.where(
|
||||
input_ids == self.config.audio_token_index)
|
||||
|
||||
# 3. Create the final_input_ids, already padded to the maximum position
|
||||
final_input_ids = torch.full((batch_size, max_token_num),
|
||||
self.config.audio_token_index,
|
||||
dtype=input_ids.dtype,
|
||||
device=device)
|
||||
|
||||
# 4. Fill the final_input_ids based on the mask. If we have ["hey", "<audio>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the audio features
|
||||
final_input_ids[batch_indices,
|
||||
text_to_overwrite] = input_ids[batch_indices,
|
||||
non_audio_indices]
|
||||
vocab_size = self.config.vocab_size
|
||||
fake_prompt_id = torch.arange(vocab_size,
|
||||
vocab_size + num_audio_tokens.sum(),
|
||||
device=device)
|
||||
batch_indices, audio_indices = torch.where(
|
||||
final_input_ids == self.config.audio_token_index)
|
||||
final_input_ids[batch_indices, audio_indices] = fake_prompt_id
|
||||
|
||||
input_ids = final_input_ids.contiguous().to(dtype=torch.int32,
|
||||
device=self.gpu_device)
|
||||
input_ids[batch_indices, audio_indices] = fake_prompt_id
|
||||
input_lengths = torch.tensor(input_ids.size(1),
|
||||
dtype=torch.int32,
|
||||
device=self.gpu_device)
|
||||
@ -568,8 +537,7 @@ class QWenInfer(object):
|
||||
|
||||
# print(f"extra_ids: {extra_ids}")
|
||||
output_ids, Qwen_time = self.generate_for_qwen_audio(
|
||||
input_ids, args, prompt_table, tasks, task_vocab_size, extra_ids,
|
||||
run_time)
|
||||
input_ids, args, prompt_table, extra_ids, run_time)
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
input_lengths = torch.tensor([input_ids.size(1)],
|
||||
|
||||
@ -41,9 +41,6 @@ def test_llm_qwen2audio_single_gpu(qwen2audio_example_root, llm_qwen_model_root,
|
||||
"Build & run qwen2audio on 1 gpu."
|
||||
workspace = llm_venv.get_working_directory()
|
||||
|
||||
# https://nvbugs/5136784
|
||||
llm_venv.run_cmd(['-m', 'pip', 'install', 'transformers==4.47.1'])
|
||||
|
||||
print("Generate audio engine...")
|
||||
audio_engine_dir = f"{engine_dir}/audio"
|
||||
audio_cmd = [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user