mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-07 11:41:47 +08:00
[None][fix] Always reset drafting states for GuidedDecoder (#10899)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent
fafc22e3d4
commit
ccdd8461ac
@ -276,6 +276,7 @@ class GuidedDecoder:
|
|||||||
assert len(req.draft_tokens) == 0
|
assert len(req.draft_tokens) == 0
|
||||||
self.num_advanced_draft_tokens[
|
self.num_advanced_draft_tokens[
|
||||||
slot] += self.num_advanced_tokens[slot]
|
slot] += self.num_advanced_tokens[slot]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Guided decoding error: {str(e)}"
|
error_msg = f"Guided decoding error: {str(e)}"
|
||||||
failed_requests.append((req.request_id, error_msg))
|
failed_requests.append((req.request_id, error_msg))
|
||||||
@ -406,10 +407,9 @@ class GuidedDecoder:
|
|||||||
|
|
||||||
for req in requests.valid_requests():
|
for req in requests.valid_requests():
|
||||||
slot = req.seq_slot
|
slot = req.seq_slot
|
||||||
if self.num_advanced_draft_tokens[slot] <= 0:
|
if self.num_advanced_draft_tokens[slot] > 0:
|
||||||
continue
|
self.grammar_matchers[slot].rollback(
|
||||||
self.grammar_matchers[slot].rollback(
|
self.num_advanced_draft_tokens[slot])
|
||||||
self.num_advanced_draft_tokens[slot])
|
|
||||||
# Reset the drafting states.
|
# Reset the drafting states.
|
||||||
self.num_advanced_draft_tokens[slot] = 0
|
self.num_advanced_draft_tokens[slot] = 0
|
||||||
self.is_draft_terminated[slot] = False
|
self.is_draft_terminated[slot] = False
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user