Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
fix: fix accuracy and illegal memory access issues when using mtp + attention dp (#4379)
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
commit 7d356efc7d (parent 2ce05c3ab4)
@@ -239,7 +239,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
        self.py_max_new_tokens = self.max_new_tokens
        self.py_batch_idx = None
        self.py_rewind_len = 0
        self.py_draft_tokens = self.draft_tokens
        self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
        self.py_last_draft_tokens = None
        self.py_decoding_iter = 0
        self.is_attention_dp_dummy = False
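The initializer change above normalizes py_draft_tokens to always be a list, so downstream code can test for draft tokens with a length check rather than a None check. A minimal, self-contained sketch of the pattern (the class and field names below are illustrative stand-ins, not the actual LlmRequest binding):

class DraftState:
    """Toy stand-in for the request-side draft-token bookkeeping."""

    def __init__(self, draft_tokens=None):
        # Normalize to a list so every consumer can call len() safely.
        self.py_draft_tokens = [] if draft_tokens is None else list(draft_tokens)

    def is_extend(self) -> bool:
        # Emptiness check replaces the old "is not None" test.
        return len(self.py_draft_tokens) > 0


assert not DraftState().is_extend()
assert DraftState([7, 8]).is_extend()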
@@ -1087,7 +1087,8 @@ class PyTorchModelEngine(ModelEngine):
        extend_requests = []
        generation_requests = []
        for request in scheduled_requests.generation_requests:
            if request.py_draft_tokens is not None or next_draft_tokens_device is not None:
            if len(request.py_draft_tokens
                   ) > 0 or next_draft_tokens_device is not None:
                extend_requests.append(request)
            else:
                generation_requests.append(request)
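This hunk switches the split of generation requests to the length check: requests that currently carry draft tokens, or any generation request when the overlap scheduler supplies device-side draft tokens, take the extend path; everything else stays on the plain one-token path. A hedged sketch of that partition (the _Req dataclass and function name are assumptions for illustration):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class _Req:
    # Hypothetical minimal request: only the fields the split below needs.
    py_draft_tokens: List[int] = field(default_factory=list)
    py_batch_idx: Optional[int] = None


def split_generation_requests(generation_requests, next_draft_tokens_device=None):
    """Requests with draft tokens (or any request when device-side draft tokens
    exist) go to the 'extend' path; the rest stay on plain generation."""
    extend, plain = [], []
    for req in generation_requests:
        if len(req.py_draft_tokens) > 0 or next_draft_tokens_device is not None:
            extend.append(req)
        else:
            plain.append(req)
    return extend, plain


extend, plain = split_generation_requests([_Req([3, 4]), _Req()])
assert len(extend) == 1 and len(plain) == 1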
@@ -1107,32 +1108,43 @@ class PyTorchModelEngine(ModelEngine):
        # will contain previous batch indices of generation requests
        previous_batch_indices = []
        previous_pos_indices = []
        request_ids_with_previous_batch = []
        num_extend_reqs_wo_previous_batch = 0
        for request in extend_requests:
            if next_draft_tokens_device is None or request.py_batch_idx is None:
                num_draft_tokens = len(request.py_draft_tokens)
                # the request has no previous device tensors:
                # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
                # (2) request.py_batch_idx is None, which means the request has no previous batch.
                # the second condition includes dummy generation requests created for CUDA graph padding or
                # attention DP. These dummy generation requests should be at the head of generation_requests.
                # TODO: move the dummy generation requests to the end of generation_requests to align with
                # the logic for those requests in generation_requests.
                # get token ids, including input token ids and draft token ids
                input_ids.append(request.get_last_tokens(0))
                gather_ids.append(len(input_ids) - 1)
                sequence_lengths.append(1 + num_draft_tokens)
                past_seen_token_num = request.max_beam_num_tokens - 1
                position_ids.append(past_seen_token_num)
                draft_lens.append(num_draft_tokens)
                prompt_lengths.append(num_draft_tokens + 1)
                # draft tokens
                input_ids.extend(request.py_draft_tokens)
                draft_tokens.extend(request.py_draft_tokens)
                # get other ids and lengths
                num_draft_tokens = len(request.py_draft_tokens)
                past_seen_token_num = request.max_beam_num_tokens - 1
                draft_lens.append(num_draft_tokens)
                prompt_lengths.append(request.py_prompt_len)
                sequence_lengths.append(1 + num_draft_tokens)
                gather_ids.extend(
                    list(
                        range(
                            len(input_ids) - num_draft_tokens, len(input_ids))))
                        range(len(position_ids),
                              len(position_ids) + 1 + self.max_draft_len)))
                position_ids.extend(
                    list(
                        range(past_seen_token_num + 1,
                        range(past_seen_token_num,
                              past_seen_token_num + 1 + num_draft_tokens)))
                draft_tokens.extend(request.py_draft_tokens)
                num_cached_tokens_per_seq.append(past_seen_token_num)
                request_ids.append(request.py_request_id)
                # update batch index
                request.py_batch_idx = batch_idx
                batch_idx += 1
                num_extend_reqs_wo_previous_batch += 1
            else:
                # batch index
                # update batch index
                previous_batch_idx = request.py_batch_idx
                request.py_batch_idx = batch_idx
                batch_idx += 1
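The hunk above reworks how an extend request without previous device tensors is flattened into the batch: its last accepted token plus its draft tokens are appended to the input, position ids continue from the number of already-cached tokens, and a new counter (num_extend_reqs_wo_previous_batch) remembers how many such requests sit at the head of the batch. A toy version of that per-request bookkeeping, under assumed field names and a simplified dict-based batch:

def append_extend_request(inputs, last_token, draft_tokens, past_seen_token_num):
    """Flatten one accepted token plus its draft tokens into the batch and
    record which rows of the flattened input belong to this request."""
    num_draft = len(draft_tokens)
    start = len(inputs["input_ids"])
    inputs["input_ids"].append(last_token)
    inputs["input_ids"].extend(draft_tokens)
    # Positions continue right after the tokens already in the KV cache.
    inputs["position_ids"].extend(
        range(past_seen_token_num, past_seen_token_num + 1 + num_draft))
    inputs["gather_ids"].extend(range(start, start + 1 + num_draft))
    inputs["sequence_lengths"].append(1 + num_draft)
    inputs["num_cached_tokens_per_seq"].append(past_seen_token_num)


batch = {k: [] for k in ("input_ids", "position_ids", "gather_ids",
                         "sequence_lengths", "num_cached_tokens_per_seq")}
append_extend_request(batch, last_token=42, draft_tokens=[7, 8], past_seen_token_num=10)
assert batch["input_ids"] == [42, 7, 8]
assert batch["position_ids"] == [10, 11, 12]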
@@ -1157,8 +1169,10 @@ class PyTorchModelEngine(ModelEngine):
                num_cached_tokens_per_seq.append(past_seen_token_num +
                                                 self.max_draft_len + 1)
                prompt_lengths.append(request.py_prompt_len)
                request_ids_with_previous_batch.append(request.py_request_id)

                request_ids.append(request.py_request_id)
        # move requests with previous batch to the end of the list
        request_ids.extend(request_ids_with_previous_batch)

        sequence_lengths.extend([1] * len(generation_requests))
        gather_ids.extend(
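Per the "move requests with previous batch to the end of the list" comment, ids of requests that reuse device-side tensors from the previous iteration are collected separately and appended after all other ids. A minimal sketch of that ordering rule (function name is an assumption):

def order_request_ids(ids_without_previous_batch, ids_with_previous_batch):
    """Requests that reuse previous-iteration device tensors go to the end of
    the id list, so freshly staged requests keep the leading batch slots."""
    return list(ids_without_previous_batch) + list(ids_with_previous_batch)


assert order_request_ids([0, 2], [1, 3]) == [0, 2, 1, 3]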
@@ -1191,12 +1205,20 @@ class PyTorchModelEngine(ModelEngine):
            batch_idx += 1

        num_tokens = len(input_ids)
        num_draft_tokens = len(draft_tokens)
        previous_batchs = len(previous_batch_indices)
        # if there are requests that do not have a previous batch, copy input_ids and draft_tokens
        if num_tokens > 0:
            input_ids = torch.tensor(input_ids,
                                     dtype=torch.int,
                                     pin_memory=True)
            self.input_ids_cuda[:num_tokens].copy_(input_ids, non_blocking=True)
        if num_draft_tokens > 0:
            draft_tokens = torch.tensor(draft_tokens,
                                        dtype=torch.int,
                                        pin_memory=True)
            self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                             non_blocking=True)
        if next_draft_tokens_device is not None:
            if len(previous_batch_indices) > 0:
                previous_batch_indices = torch.tensor(previous_batch_indices,
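The copies above follow a common PyTorch pattern: build a pinned host tensor from the Python lists and issue a non-blocking copy into a slice of a preallocated device buffer. A hedged, self-contained sketch of that pattern (buffer name and helper are assumptions; pinning is guarded so the snippet also runs on CPU-only builds):

import torch


def stage_to_device(host_values, device_buffer):
    """Copy a Python list into the front of a preallocated buffer using a
    pinned staging tensor and a non-blocking H2D transfer when CUDA is there."""
    if not host_values:
        return 0
    staging = torch.tensor(host_values, dtype=torch.int,
                           pin_memory=torch.cuda.is_available())
    n = staging.numel()
    device_buffer[:n].copy_(staging, non_blocking=True)
    return n


buf = torch.zeros(16, dtype=torch.int,
                  device="cuda" if torch.cuda.is_available() else "cpu")
stage_to_device([5, 6, 7], buf)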
@@ -1215,26 +1237,39 @@ class PyTorchModelEngine(ModelEngine):
                                                      non_blocking=True)
                # previous draft tokens
                previous_batch_draft_tokens = previous_batchs * self.max_draft_len
                self.draft_tokens_cuda[:previous_batch_draft_tokens].copy_(
                    next_draft_tokens_device[
                self.draft_tokens_cuda[
                    num_draft_tokens:num_draft_tokens +
                    previous_batch_draft_tokens].copy_(next_draft_tokens_device[
                        self.previous_batch_indices_cuda[:previous_batchs], :].
                    flatten(),
                    non_blocking=True)
                    flatten(),
                    non_blocking=True)
                # prepare data for the preprocess inputs
                kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
                pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
                    1 + self.max_draft_len)
                pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
                pre_batch_start_idx = num_extend_reqs_wo_previous_batch
                pre_batch_end_idx = pre_batch_start_idx + previous_batchs
                previous_pos_indices = torch.tensor(previous_pos_indices,
                                                    dtype=torch.int,
                                                    pin_memory=True)
                self.previous_pos_indices_cuda[:previous_batch_tokens].copy_(
                    previous_pos_indices, non_blocking=True)
                self.previous_pos_id_offsets_cuda[:previous_batch_tokens].copy_(
                    new_tokens_lens_device[
                        self.previous_pos_indices_cuda[:previous_batch_tokens]],
                    non_blocking=True)
                self.previous_kv_lens_offsets_cuda[:previous_batchs].copy_(
                    kv_len_offsets_device[
                        self.previous_batch_indices_cuda[:previous_batchs]],
                    non_blocking=True)
                self.previous_pos_indices_cuda[
                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
                        previous_pos_indices, non_blocking=True)
                self.previous_pos_id_offsets_cuda[
                    pre_tokens_start_idx:pre_tokens_end_idx].copy_(
                        new_tokens_lens_device[self.previous_pos_indices_cuda[
                            pre_tokens_start_idx:pre_tokens_end_idx]],
                        non_blocking=True)
                self.previous_kv_lens_offsets_cuda[
                    pre_batch_start_idx:pre_batch_end_idx].copy_(
                        kv_len_offsets_device[
                            self.previous_batch_indices_cuda[:previous_batchs]],
                        non_blocking=True)
                # for the requests that do not have previous batch, set the previous_pos_id_offsets and
                # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
                self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
                self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
            else:
                # change the data to zeros to skip the value changes in _preprocess_inputs
                self.previous_pos_id_offsets_cuda *= 0
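This hunk appears to be the core of the accuracy fix: data carried over from the previous batch is now written starting after the buffer slots occupied by extend requests that have no previous batch (the pre_tokens_start_idx / pre_batch_start_idx offsets), and the head region is zeroed so _preprocess_inputs leaves those fresh requests untouched. A toy version of that offset-and-zero bookkeeping, with assumed argument names and a plain CPU tensor:

import torch


def scatter_previous_batch_offsets(pos_id_offsets, new_tokens_lens,
                                   previous_pos_indices,
                                   num_tokens_wo_previous_batch):
    """Write previous-batch offsets after the fresh-request slots and zero the
    head region so fresh requests are not modified by the preprocess step."""
    start = num_tokens_wo_previous_batch
    end = start + len(previous_pos_indices)
    idx = torch.tensor(previous_pos_indices, dtype=torch.long)
    pos_id_offsets[start:end] = new_tokens_lens[idx]
    pos_id_offsets[:start] = 0
    return pos_id_offsets


offsets = torch.full((8,), -1, dtype=torch.int)
new_lens = torch.tensor([1, 2, 3, 2], dtype=torch.int)
scatter_previous_batch_offsets(offsets, new_lens, [0, 1, 2], 3)
assert offsets[:3].eq(0).all() and offsets[3:6].tolist() == [1, 2, 3]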
@@ -1305,12 +1340,6 @@ class PyTorchModelEngine(ModelEngine):

        if spec_metadata is not None:
            total_draft_lens = sum(draft_lens)
            if len(draft_tokens) > 0:
                draft_tokens = torch.tensor(draft_tokens,
                                            dtype=torch.int,
                                            pin_memory=True)
                self.draft_tokens_cuda[:len(draft_tokens)].copy_(
                    draft_tokens, non_blocking=True)
            spec_metadata.draft_tokens = self.draft_tokens_cuda[:
                                                                total_draft_lens]
            spec_metadata.request_ids = request_ids
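Here the speculative-decoding metadata keeps a view into the preallocated draft-token buffer, sliced to the total number of draft tokens actually present in the batch; the redundant host-to-device copy that this hunk removes was already done earlier. A small sketch of the slicing idea (function name and layout assumption: concatenated per-request draft tokens):

import torch


def build_spec_draft_view(draft_tokens_buffer, per_request_draft_lens):
    """Return a view over the front of the preallocated buffer covering all
    draft tokens present in this batch."""
    total = sum(per_request_draft_lens)
    return draft_tokens_buffer[:total]


buffer = torch.arange(32, dtype=torch.int)
view = build_spec_draft_view(buffer, [2, 0, 3])
assert view.numel() == 5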
@@ -1488,8 +1488,7 @@ class PyExecutor:
                request_ids=[0],
                is_gen=not self.has_context_request,
                prepare_resource=not self.has_context_request,
                max_num_draft_tokens=0
                if self.has_context_request else self.max_draft_tokens,
                max_num_draft_tokens=self.max_draft_tokens,
            )[0]
            llm_request.is_attention_dp_dummy = True
            self.active_requests.append(llm_request)
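This hunk appears to drop the "0 if a context request exists" special case: the attention-DP dummy request now always reserves room for the full draft length. A hedged sketch of that idea, with a stub factory standing in for the real resource manager call (all names below are assumptions):

from types import SimpleNamespace


def _stub_factory(request_ids, is_gen, prepare_resource, max_num_draft_tokens):
    # Stand-in for the real add_dummy_requests call in this sketch.
    return [SimpleNamespace(request_id=r, is_gen=is_gen,
                            prepare_resource=prepare_resource,
                            max_num_draft_tokens=max_num_draft_tokens,
                            is_attention_dp_dummy=False) for r in request_ids]


def make_attention_dp_dummy(factory, has_context_request, max_draft_tokens):
    """A rank with no real work still enqueues one placeholder request so
    attention-DP collectives stay in lockstep; it reserves the full draft
    length up front instead of dropping to zero."""
    dummy = factory(
        request_ids=[0],
        is_gen=not has_context_request,
        prepare_resource=not has_context_request,
        max_num_draft_tokens=max_draft_tokens,
    )[0]
    dummy.is_attention_dp_dummy = True
    return dummy


req = make_attention_dp_dummy(_stub_factory, has_context_request=False, max_draft_tokens=3)
assert req.is_attention_dp_dummy and req.max_num_draft_tokens == 3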
@@ -1525,7 +1524,8 @@ class PyExecutor:
            req.decoding_iter = 1
            req.py_decoding_iter = 1
            first_gen_tokens = req.context_phase_params.first_gen_tokens
            req.py_draft_tokens = req.context_phase_params.draft_tokens
            ctx_draft_tokens = req.context_phase_params.draft_tokens
            req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
            beam_width = req.sampling_config.beam_width
            for beam in range(0, beam_width):
                req.add_new_token(first_gen_tokens[beam], beam)
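The same normalization is applied when a request arrives with context-phase results: its draft tokens become an empty list when absent, and the first generated token is recorded per beam. A dict-based sketch of the hand-off (field names are assumptions):

def seed_generation_from_context(req, first_gen_tokens, ctx_draft_tokens, beam_width):
    """Normalize context-phase draft tokens to a list and seed the first
    generated token for each beam."""
    req["py_decoding_iter"] = 1
    req["py_draft_tokens"] = [] if ctx_draft_tokens is None else list(ctx_draft_tokens)
    req["tokens"] = [[t] for t in first_gen_tokens[:beam_width]]
    return req


req = seed_generation_from_context({}, first_gen_tokens=[11], ctx_draft_tokens=None, beam_width=1)
assert req["py_draft_tokens"] == [] and req["tokens"] == [[11]]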
@@ -300,15 +300,13 @@ class KVCacheManager(BaseResourceManager):
                                   req_beam_width, req)
            for _ in range(self.num_extra_kv_tokens):
                self.impl.add_token(req.py_request_id)
            if req.py_draft_tokens is not None:
                for _ in range(len(req.py_draft_tokens)):
                    self.impl.add_token(req.py_request_id)
            for _ in range(len(req.py_draft_tokens)):
                self.impl.add_token(req.py_request_id)

        for req in generation_batch:
            self.impl.add_token(req.py_request_id)
            if req.py_draft_tokens is not None:
                for _ in range(len(req.py_draft_tokens)):
                    self.impl.add_token(req.py_request_id)
            for _ in range(len(req.py_draft_tokens)):
                self.impl.add_token(req.py_request_id)
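Because py_draft_tokens is now always a list, the None guards around the KV-slot loops can go away. A minimal sketch of the slot accounting (the add_token callback stands in for the real KV-cache binding):

def reserve_kv_slots(add_token, request_id, num_new_tokens, num_extra_kv_tokens, draft_tokens):
    """Every newly generated token, every configured extra KV token, and every
    draft token needs its own KV-cache slot; draft_tokens is always a list, so
    no None check is needed."""
    for _ in range(num_new_tokens):
        add_token(request_id)
    for _ in range(num_extra_kv_tokens):
        add_token(request_id)
    for _ in range(len(draft_tokens)):
        add_token(request_id)


calls = []
reserve_kv_slots(lambda rid: calls.append(rid), request_id=7,
                 num_new_tokens=1, num_extra_kv_tokens=0, draft_tokens=[1, 2])
assert len(calls) == 3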
    def add_dummy_requests(
        self,
@@ -328,7 +326,11 @@ class KVCacheManager(BaseResourceManager):
        requests = []
        for i, req_id in enumerate(request_ids):
            sampling_params = SamplingParams()
            token_num = token_nums[i] if token_nums is not None else 1
            # Here 1+max_num_draft_tokens is used to extend the prompt length to
            # a non-zero number to skip illegal memory access issue in MLA kernel
            # during warmup.
            token_num = token_nums[
                i] if token_nums is not None else 1 + max_num_draft_tokens
            encoder_input_tokens = [
                1
            ] * token_num if self.impl.cross_kv else None
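As the added comment explains, a dummy warmup request now gets a default token count of 1 + max_num_draft_tokens rather than 1, so its generation-phase prompt length never collapses to zero and the MLA attention kernel is not fed a zero-length prompt. A tiny sketch of that length rule (function name is an assumption):

def dummy_prompt_token_num(token_nums, i, max_num_draft_tokens):
    """Default dummy requests to 1 + max_num_draft_tokens tokens so the
    warmup prompt length (token_num - 1) stays non-zero."""
    if token_nums is not None:
        return token_nums[i]
    return 1 + max_num_draft_tokens


assert dummy_prompt_token_num(None, 0, max_num_draft_tokens=3) == 4
assert dummy_prompt_token_num([16], 0, max_num_draft_tokens=3) == 16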
@@ -343,12 +345,16 @@ class KVCacheManager(BaseResourceManager):
            req.paged_kv_block_ids = []
            if prepare_resource:
                self.impl.add_sequence(req_id, token_num, beam_width, req)
                for _ in range(self.num_extra_kv_tokens):
                    self.impl.add_token(req_id)
            if is_gen:
                req.state = LlmRequestState.GENERATION_IN_PROGRESS
                req.prompt_len = token_num - 1
                req.py_prompt_len = req.prompt_len
                if max_num_draft_tokens > 0:
                    req.py_draft_tokens = [0] * max_num_draft_tokens
                    req.py_draft_tokens = [1] * max_num_draft_tokens
                    if prepare_resource:
                        for _ in range(max_num_draft_tokens):
                            self.impl.add_token(req_id)
            requests.append(req)
        return requests
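For dummy generation requests, the diff switches the placeholder draft token id from 0 to 1 and, when resources are prepared, reserves one KV slot per draft token. A dict-based sketch of that set-up (field names and the add_token callback are assumptions):

def finalize_dummy_gen_request(req, token_num, max_num_draft_tokens, add_token, prepare_resource):
    """Mark the dummy request as mid-generation, fill its draft slots with a
    harmless non-zero token id, and reserve one KV slot per draft token."""
    req["state"] = "GENERATION_IN_PROGRESS"
    req["py_prompt_len"] = token_num - 1
    if max_num_draft_tokens > 0:
        req["py_draft_tokens"] = [1] * max_num_draft_tokens
        if prepare_resource:
            for _ in range(max_num_draft_tokens):
                add_token(req["id"])
    return req


slots = []
finalize_dummy_gen_request({"id": 0}, token_num=4, max_num_draft_tokens=3,
                           add_token=lambda rid: slots.append(rid), prepare_resource=True)
assert len(slots) == 3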
@@ -297,7 +297,7 @@ class TorchSampler(Sampler):
        extend_requests = []
        generation_requests = []
        for request in scheduled_requests.generation_requests:
            if request.py_draft_tokens is not None:
            if len(request.py_draft_tokens) > 0:
                extend_requests.append(request)
            else:
                generation_requests.append(request)
@@ -361,7 +361,9 @@ class TorchSampler(Sampler):
        for request in scheduled_requests.generation_requests:
            if request.state == LlmRequestState.GENERATION_COMPLETE:
                continue
            assert request.py_draft_tokens is None, "Speculative decoding not supported in SeparateDecoder."
            assert len(
                request.py_draft_tokens
            ) == 0, "Speculative decoding not supported in SeparateDecoder."
            token_logits = logits[idx:idx + 1, :]
            new_token, probs = decode_single_request(request, token_logits)
            new_tokens_device_array.append(new_token)
@@ -184,7 +184,7 @@ class BindMicroBatchScheduler(MicroBatchScheduler):
        self, active_requests: RequestList, inflight_request_ids: set[int]
    ) -> tuple[list[LlmRequest], list[LlmRequest]]:
        for request in active_requests:
            if request.py_draft_tokens is not None:
            if len(request.py_draft_tokens) > 0:
                request.draft_tokens = request.py_draft_tokens
        return self.impl(active_requests, inflight_request_ids,
                         self.max_batch_size, self.max_num_tokens)
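Before handing requests to the C++ micro-batch scheduler, requests that currently carry draft tokens mirror them into the field the binding reads, so the token budget accounts for the extra draft positions; the length check again replaces the old None check. A dict-based sketch of that pre-scheduling sync (field names are assumptions):

def sync_draft_tokens_for_scheduling(active_requests):
    """Mirror Python-side draft tokens into the field the scheduler reads."""
    for req in active_requests:
        if len(req["py_draft_tokens"]) > 0:
            req["draft_tokens"] = req["py_draft_tokens"]
    return active_requests


reqs = sync_draft_tokens_for_scheduling([
    {"py_draft_tokens": [5, 6], "draft_tokens": None},
    {"py_draft_tokens": [], "draft_tokens": None},
])
assert reqs[0]["draft_tokens"] == [5, 6] and reqs[1]["draft_tokens"] is None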
@@ -365,12 +365,7 @@ accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://n
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5231468)
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://nvbugs/5231310)
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image] SKIP (https://nvbugs/5233423)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5234002)
examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8] SKIP (https://nvbugs/5234164)
full::GH200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (https://nvbugs/5250460)
full::GH200/examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-2-27b-it-other-bfloat16-8] SKIP (https://nvbugs/5250460)
@@ -415,7 +410,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5285965)
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523)
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523)