chore: Handle qwen2audio input ids expansion during processing (#3080)

* Handle qwen2audio input ids expansion during processing

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

* remove more dead code

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

* fix yapf

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>

---------

Signed-off-by: Aurelien Chartier <achartier@nvidia.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
Aurelien Chartier 2025-03-26 00:00:27 -07:00 committed by GitHub
parent 3c7cb6629c
commit 0ec7b5701f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 42 deletions

View File

@ -328,13 +328,9 @@ class QWenInfer(object):
input_tokens,
args,
prompt_table=None,
tasks=None,
task_vocab_size=None,
extra_ids=None,
run_time=1,
):
input_ids = None
input_lengths = None
input_ids = torch.as_tensor(input_tokens,
device=self.gpu_device,
dtype=torch.int32)
@ -398,8 +394,7 @@ class QWenInfer(object):
stream,
history=None,
past_audio_features=None,
run_time=1,
gpu_id=0):
run_time=1):
assert input_text, "input_text must be provided"
assert torch.cuda.is_available(), "no gpu available"
# preprocess on CPU maybe faster
@ -464,9 +459,7 @@ class QWenInfer(object):
# 1. Create a mask to know where special audio tokens are
special_audio_token_mask = input_ids == self.config.audio_token_index
special_audio_token_num = special_audio_token_mask.sum().item()
if past_audio_features is None:
assert special_audio_token_num == num_audios, f'special_audio_token_num {special_audio_token_num} should be equal to num_audios {num_audios}'
else:
if past_audio_features is not None:
assert isinstance(past_audio_features,
list), f'past_audio_features should be a list'
assert (
@ -497,40 +490,16 @@ class QWenInfer(object):
batch_indices, non_audio_indices = torch.where(
input_ids != self.config.audio_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged audio-text sequence.
# `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
# `torch.cumsum` computes how each audio token shifts subsequent text token positions.
token_placeholder_num = torch.zeros_like(input_ids, device=device)
token_placeholder_num[
special_audio_token_mask] = num_audio_tokens.long() - 1
token_placeholder_num = token_placeholder_num + 1
new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
max_token_num = token_placeholder_num.sum(-1).max()
text_to_overwrite = new_token_positions[batch_indices,
non_audio_indices]
# 2. Fill the final input ids based on the mask.
batch_indices, audio_indices = torch.where(
input_ids == self.config.audio_token_index)
# 3. Create the final_input_ids, already padded to the maximum position
final_input_ids = torch.full((batch_size, max_token_num),
self.config.audio_token_index,
dtype=input_ids.dtype,
device=device)
# 4. Fill the final_input_ids based on the mask. If we have ["hey" "<audio>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the audio features
final_input_ids[batch_indices,
text_to_overwrite] = input_ids[batch_indices,
non_audio_indices]
vocab_size = self.config.vocab_size
fake_prompt_id = torch.arange(vocab_size,
vocab_size + num_audio_tokens.sum(),
device=device)
batch_indices, audio_indices = torch.where(
final_input_ids == self.config.audio_token_index)
final_input_ids[batch_indices, audio_indices] = fake_prompt_id
input_ids = final_input_ids.contiguous().to(dtype=torch.int32,
device=self.gpu_device)
input_ids[batch_indices, audio_indices] = fake_prompt_id
input_lengths = torch.tensor(input_ids.size(1),
dtype=torch.int32,
device=self.gpu_device)
@ -568,8 +537,7 @@ class QWenInfer(object):
# print(f"extra_ids: {extra_ids}")
output_ids, Qwen_time = self.generate_for_qwen_audio(
input_ids, args, prompt_table, tasks, task_vocab_size, extra_ids,
run_time)
input_ids, args, prompt_table, extra_ids, run_time)
runtime_rank = tensorrt_llm.mpi_rank()
input_lengths = torch.tensor([input_ids.size(1)],

View File

@ -41,9 +41,6 @@ def test_llm_qwen2audio_single_gpu(qwen2audio_example_root, llm_qwen_model_root,
"Build & run qwen2audio on 1 gpu."
workspace = llm_venv.get_working_directory()
# https://nvbugs/5136784
llm_venv.run_cmd(['-m', 'pip', 'install', 'transformers==4.47.1'])
print("Generate audio engine...")
audio_engine_dir = f"{engine_dir}/audio"
audio_cmd = [