from typing import Optional

from tensorrt_llm.functional import Tensor, silu
from tensorrt_llm.layers import ColumnLinear
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.module import Module, ModuleList

from ..._utils import str_dtype_to_trt


class ResBlock(Module):
    """Residual feed-forward block used by the drafter: y = x + SiLU(Linear(x))."""

    def __init__(self,
                 exit_dim: int,
                 dtype: Optional[str],
                 mapping: Mapping = Mapping()):
        super().__init__()
        # Tensor-parallel linear layer; gather_output=True so every rank holds
        # the full exit_dim output for the residual addition in forward().
        self.linear = ColumnLinear(
            exit_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + silu(self.linear(x))
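

# Illustrative sketch (not part of the model): a NumPy reference of the
# computation ResBlock adds to the TensorRT graph, assuming the linear layer
# applies x @ W^T + b with W of shape (exit_dim, exit_dim) and b of shape
# (exit_dim,). The shapes and the y = x + SiLU(W x + b) reading are inferred
# from the layer above, not taken from any TensorRT-LLM API.
def _resblock_reference(x, weight, bias):
    import numpy as np
    pre = x @ weight.T + bias          # linear projection
    act = pre / (1.0 + np.exp(-pre))   # SiLU(z) = z * sigmoid(z)
    return x + act                     # residual connection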


class Drafter(Module):
    """ReDrafter drafting head: an optional input projection, a stack of
    ResBlocks, and an LM head producing draft-token logits."""

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        exit_dim: int,
        vocab_size: int,
        dtype: Optional[str] = None,
        is_rnn: bool = False,
        mapping: Mapping = Mapping(),
    ):
        super().__init__()
        self.num_layers = num_layers
        self.is_rnn = is_rnn
        self.dtype = str_dtype_to_trt(dtype)

        # The drafter input is twice the base-model hidden size (hidden state
        # concatenated with a token embedding); project it to exit_dim unless
        # the two widths already match.
        input_dim = 2 * hidden_size
        self.input_proj = (None if input_dim == exit_dim else ColumnLinear(
            input_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        ))

        self.layers = ModuleList([
            ResBlock(exit_dim, dtype, mapping) for _ in range(self.num_layers)
        ])
        self.lm_head = ColumnLinear(
            exit_dim,
            vocab_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        # Recurrent token embedding (see rnn_embed below): rnn_w maps the
        # current token embedding, rnn_u maps the previous recurrent state.
        if is_rnn:
            self.rnn_u = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=True,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )
            self.rnn_w = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=False,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )

    @classmethod
    def from_config(cls, config, vocab_size_padded):
        kwargs = {
            "num_layers": config.redrafter_num_layers,
            "hidden_size": config.redrafter_hidden_size,
            "exit_dim": config.redrafter_exit_dim,
            "vocab_size": vocab_size_padded,
            "dtype": config.dtype,
            "is_rnn": config.redrafter_is_rnn,
            "mapping": config.mapping,
        }
        return cls(**kwargs)

    def forward(self, x: Tensor) -> Tensor:
        hidden_states = self.input_proj(x) if self.input_proj is not None else x
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.lm_head(hidden_states)

    def rnn_embed(self, x: Tensor, prev: Optional[Tensor] = None) -> Tensor:
        """Apply rnn_w to the current embedding and, when a previous state is
        given, add rnn_u applied to that state."""
        assert self.is_rnn, "This function should not be called when redrafter_is_rnn is False."
        w_embd = self.rnn_w(x)
        return w_embd if prev is None else w_embd + self.rnn_u(prev)
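

# Illustrative sketch (not part of the model): how rnn_embed could be chained
# over a sequence of draft-token embeddings when is_rnn is enabled, written for
# NumPy-style arrays. The e_t = W x_t + (U e_{t-1} + b_u) recurrence and the
# square (hidden_size, hidden_size) weight shapes are assumptions drawn from
# the layer definitions above, not from any TensorRT-LLM API.
def _rnn_embed_reference(token_embeddings, rnn_w, rnn_u, rnn_u_bias):
    """token_embeddings: iterable of (hidden_size,) arrays; rnn_w and rnn_u:
    (hidden_size, hidden_size) arrays; rnn_u_bias: (hidden_size,) array."""
    prev = None
    states = []
    for x in token_embeddings:
        w_embd = rnn_w @ x  # rnn_w is built with bias=False above
        prev = w_embd if prev is None else w_embd + (rnn_u @ prev + rnn_u_bias)
        states.append(prev)
    return states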