mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][fix] revert https://github.com/NVIDIA/TensorRT-LLM/pull/10445. (#10547)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent
56e779d09f
commit
afa55c12b6
@ -953,14 +953,6 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
|
||||
hidden_states = hidden_states[:attn_metadata.num_tokens]
|
||||
|
||||
if self.draft_model is not None:
|
||||
# For one-model speculative decoding with PP, only the last PP rank
|
||||
# has valid hidden_states from the target model. The spec_worker (which
|
||||
# runs the draft model loop) should only run on the last PP rank.
|
||||
# Non-last PP ranks return None and let the PP sync handle the results.
|
||||
mapping = self.model.model_config.mapping
|
||||
if mapping.has_pp() and not mapping.is_last_pp_rank():
|
||||
return None
|
||||
|
||||
# get logits
|
||||
logits = self.logits_processor.forward(
|
||||
hidden_states[spec_metadata.gather_ids],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user