mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9654][feat] Support DeepSeek-V32 chat template (#9814)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
This commit is contained in:
parent
52cee573ad
commit
31bc14b350
@ -848,6 +848,7 @@ exclude = [
|
||||
"tensorrt_llm/serve/tool_parser/base_tool_parser.py",
|
||||
"tensorrt_llm/serve/tool_parser/qwen3_tool_parser.py",
|
||||
"tensorrt_llm/serve/tool_parser/utils.py",
|
||||
"tensorrt_llm/tokenizer/tokenizer.py",
|
||||
"tensorrt_llm/tools/__init__.py",
|
||||
"tensorrt_llm/tools/importlib_utils.py",
|
||||
"tensorrt_llm/tools/multimodal_builder.py",
|
||||
|
||||
@ -39,6 +39,14 @@ from ..logger import logger, severity_map
|
||||
default=None,
|
||||
help="Path | Name of the tokenizer."
|
||||
"Specify this value only if using TensorRT engine as model.")
|
||||
@click.option(
|
||||
"--custom_tokenizer",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path "
|
||||
"(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer'). [Experimental]"
|
||||
)
|
||||
@click.option(
|
||||
"--backend",
|
||||
type=click.Choice(["pytorch", "tensorrt"]),
|
||||
@ -109,13 +117,13 @@ from ..logger import logger, severity_map
|
||||
default=False,
|
||||
help="Flag for disabling KV cache reuse.")
|
||||
@click.pass_context
|
||||
def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
|
||||
backend: str, max_beam_width: int, max_batch_size: int,
|
||||
max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
|
||||
ep_size: Optional[int], gpus_per_node: Optional[int],
|
||||
kv_cache_free_gpu_memory_fraction: float, trust_remote_code: bool,
|
||||
revision: Optional[str], extra_llm_api_options: Optional[str],
|
||||
disable_kv_cache_reuse: bool):
|
||||
def main(ctx, model: str, tokenizer: Optional[str],
|
||||
custom_tokenizer: Optional[str], log_level: str, backend: str,
|
||||
max_beam_width: int, max_batch_size: int, max_num_tokens: int,
|
||||
max_seq_len: int, tp_size: int, pp_size: int, ep_size: Optional[int],
|
||||
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
|
||||
trust_remote_code: bool, revision: Optional[str],
|
||||
extra_llm_api_options: Optional[str], disable_kv_cache_reuse: bool):
|
||||
logger.set_level(log_level)
|
||||
|
||||
kv_cache_config = KvCacheConfig(
|
||||
@ -125,6 +133,7 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
|
||||
llm_args = {
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
"custom_tokenizer": custom_tokenizer,
|
||||
"tensor_parallel_size": tp_size,
|
||||
"pipeline_parallel_size": pp_size,
|
||||
"moe_expert_parallel_size": ep_size,
|
||||
|
||||
@ -82,6 +82,7 @@ def _signal_handler_cleanup_child(signum, frame):
|
||||
def get_llm_args(
|
||||
model: str,
|
||||
tokenizer: Optional[str] = None,
|
||||
custom_tokenizer: Optional[str] = None,
|
||||
backend: str = "pytorch",
|
||||
max_beam_width: int = BuildConfig.model_fields["max_beam_width"].
|
||||
default,
|
||||
@ -137,6 +138,7 @@ def get_llm_args(
|
||||
"model": model,
|
||||
"scheduler_config": scheduler_config,
|
||||
"tokenizer": tokenizer,
|
||||
"custom_tokenizer": custom_tokenizer,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"pipeline_parallel_size": pipeline_parallel_size,
|
||||
"context_parallel_size": context_parallel_size,
|
||||
@ -262,6 +264,14 @@ class ChoiceWithAlias(click.Choice):
|
||||
default=None,
|
||||
help="Path | Name of the tokenizer."
|
||||
"Specify this value only if using TensorRT engine as model.")
|
||||
@click.option(
|
||||
"--custom_tokenizer",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path "
|
||||
"(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer'). [Experimental]"
|
||||
)
|
||||
@click.option("--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
@ -418,17 +428,17 @@ class ChoiceWithAlias(click.Choice):
|
||||
default=None,
|
||||
help="[Experimental] Specify a custom chat template. "
|
||||
"Can be a file path or one-liner template string")
|
||||
def serve(model: str, tokenizer: Optional[str], host: str, port: int,
|
||||
log_level: str, backend: str, max_beam_width: int,
|
||||
def serve(
|
||||
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
|
||||
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
|
||||
max_batch_size: int, max_num_tokens: int, max_seq_len: int,
|
||||
tensor_parallel_size: int, pipeline_parallel_size: int,
|
||||
context_parallel_size: int, moe_expert_parallel_size: Optional[int],
|
||||
moe_cluster_parallel_size: Optional[int],
|
||||
gpus_per_node: Optional[int], free_gpu_memory_fraction: float,
|
||||
num_postprocess_workers: int, trust_remote_code: bool,
|
||||
revision: Optional[str], extra_llm_api_options: Optional[str],
|
||||
reasoning_parser: Optional[str], tool_parser: Optional[str],
|
||||
metadata_server_config_file: Optional[str],
|
||||
moe_cluster_parallel_size: Optional[int], gpus_per_node: Optional[int],
|
||||
free_gpu_memory_fraction: float, num_postprocess_workers: int,
|
||||
trust_remote_code: bool, revision: Optional[str],
|
||||
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
|
||||
tool_parser: Optional[str], metadata_server_config_file: Optional[str],
|
||||
server_role: Optional[str],
|
||||
fail_fast_on_attention_window_too_large: bool,
|
||||
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
|
||||
@ -450,6 +460,7 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
|
||||
llm_args, _ = get_llm_args(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
custom_tokenizer=custom_tokenizer,
|
||||
backend=backend,
|
||||
max_beam_width=max_beam_width,
|
||||
max_batch_size=max_batch_size,
|
||||
|
||||
@ -24,7 +24,8 @@ from tensorrt_llm.inputs.multimodal import (MultimodalServerConfig,
|
||||
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
|
||||
MultimodalPlaceholderPlacement)
|
||||
from tensorrt_llm.llmapi.llm_utils import ModelLoader
|
||||
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
|
||||
from tensorrt_llm.tokenizer import TokenizerBase, TransformersTokenizer
|
||||
from tensorrt_llm.tokenizer.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -580,6 +581,18 @@ def apply_chat_template(
|
||||
if model_type in HF_CHAT_TEMPLATE_EXCEPTIONS:
|
||||
# special path for models like llava-llama
|
||||
return "".join([conv["content"] for conv in conversation])
|
||||
|
||||
# Handle DeepSeek V32 tokenizer with custom chat template
|
||||
if isinstance(tokenizer, DeepseekV32Tokenizer):
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages=conversation,
|
||||
tools=tools,
|
||||
**(chat_template_kwargs or {}),
|
||||
)
|
||||
if enable_tokenize:
|
||||
return tokenizer.encode(prompt)
|
||||
return prompt
|
||||
|
||||
if isinstance(tokenizer, TransformersTokenizer):
|
||||
tokenizer = tokenizer.tokenizer # we need the TokenizerBase for apply_chat_template
|
||||
|
||||
|
||||
@ -1838,6 +1838,14 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
description="The mode to initialize the tokenizer.",
|
||||
json_schema_extra={"type": "Literal['auto', 'slow']"})
|
||||
|
||||
custom_tokenizer: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Specify a custom tokenizer implementation. Accepts either: "
|
||||
"(1) a built-in alias (e.g., 'deepseek_v32'), or "
|
||||
"(2) a Python import path (e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer'). "
|
||||
"The tokenizer class must implement 'from_pretrained(path, **kwargs)' and the TokenizerBase interface.",
|
||||
status="prototype")
|
||||
|
||||
skip_tokenizer_init: bool = Field(
|
||||
default=False,
|
||||
description="Whether to skip the tokenizer initialization.")
|
||||
@ -2187,6 +2195,41 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
"""Initialize tokenizer based on configuration."""
|
||||
if self.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
elif self.custom_tokenizer:
|
||||
# If tokenizer is already a tokenizer object, custom_tokenizer is not compatible
|
||||
if isinstance(self.tokenizer,
|
||||
(TokenizerBase, PreTrainedTokenizerBase)):
|
||||
raise ValueError(
|
||||
"Cannot use custom_tokenizer when tokenizer is already a tokenizer object. "
|
||||
"Please specify a tokenizer path or leave it as None to load from model path."
|
||||
)
|
||||
|
||||
# Support short aliases for built-in tokenizers
|
||||
TOKENIZER_ALIASES = {
|
||||
'deepseek_v32':
|
||||
'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer',
|
||||
}
|
||||
|
||||
tokenizer_path = TOKENIZER_ALIASES.get(self.custom_tokenizer,
|
||||
self.custom_tokenizer)
|
||||
|
||||
# Dynamically import and use custom tokenizer
|
||||
from importlib import import_module
|
||||
try:
|
||||
module_path, class_name = tokenizer_path.rsplit('.', 1)
|
||||
module = import_module(module_path)
|
||||
tokenizer_class = getattr(module, class_name)
|
||||
# Use tokenizer path if specified, otherwise use model path
|
||||
load_path = self.tokenizer if self.tokenizer else self.model
|
||||
self.tokenizer = tokenizer_class.from_pretrained(
|
||||
load_path,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
use_fast=self.tokenizer_mode != 'slow')
|
||||
except (ValueError, ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom tokenizer '{self.custom_tokenizer}': {e}. "
|
||||
"Expected format: 'module.path.ClassName' or a recognized alias."
|
||||
) from e
|
||||
else:
|
||||
self.tokenizer = tokenizer_factory(
|
||||
self.tokenizer,
|
||||
|
||||
@ -1,365 +1,21 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
# Backward compatibility shim - the tokenizer module has moved to tensorrt_llm.tokenizer
|
||||
# All imports from tensorrt_llm.llmapi.tokenizer will continue to work.
|
||||
from tensorrt_llm.tokenizer import (TLLM_INCREMENTAL_DETOKENIZATION_BACKEND,
|
||||
TLLM_STREAM_INTERVAL_THRESHOLD,
|
||||
TokenizerBase, TransformersTokenizer,
|
||||
_llguidance_tokenizer_info,
|
||||
_xgrammar_tokenizer_info, load_hf_tokenizer,
|
||||
tokenizer_factory)
|
||||
from tensorrt_llm.tokenizer.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from .._utils import nvtx_range_debug
|
||||
from ..logger import logger
|
||||
|
||||
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = os.environ.get(
|
||||
"TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
|
||||
TLLM_STREAM_INTERVAL_THRESHOLD = int(
|
||||
os.environ.get("TLLM_STREAM_INTERVAL_THRESHOLD", "24"))
|
||||
try:
|
||||
from tokenizers.decoders import DecodeStream # noqa
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
f"HF incremental detokenization is unsupported by tokenizer<0.21.0; fallback to TRTLLM incremental detokenization."
|
||||
)
|
||||
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = "TRTLLM"
|
||||
|
||||
|
||||
class TokenizerBase(PreTrainedTokenizerBase):
|
||||
''' This is a protocol for the tokenizer. Users can implement their own tokenizer by inheriting this class. '''
|
||||
|
||||
|
||||
class TransformersTokenizer(TokenizerBase):
|
||||
''' A wrapper for the Transformers' tokenizer.
|
||||
This is the default tokenizer for LLM. '''
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)
|
||||
|
||||
def __call__(self, text: str, *args, **kwargs) -> Any:
|
||||
return self.tokenizer(text, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self.tokenizer.name_or_path
|
||||
|
||||
def encode(self, text: str, *args, **kwargs) -> List[int]:
|
||||
return self.tokenizer.encode(text, *args, **kwargs)
|
||||
|
||||
def decode(self, token_ids: List[int], *args, **kwargs) -> str:
|
||||
return self.tokenizer.decode(token_ids, *args, **kwargs)
|
||||
|
||||
def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
|
||||
return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)
|
||||
|
||||
def get_chat_template(self,
|
||||
chat_template: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None) -> str:
|
||||
return self.tokenizer.get_chat_template(chat_template, tools)
|
||||
|
||||
def apply_chat_template(
|
||||
self, conversation: Union[List[Dict[str, str]],
|
||||
List[List[Dict[str, str]]]], *args,
|
||||
**kwargs) -> Union[str, List[int], List[str], List[List[int]]]:
|
||||
return self.tokenizer.apply_chat_template(conversation, *args, **kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.tokenizer})"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
|
||||
**kwargs)
|
||||
return cls(tokenizer)
|
||||
|
||||
def save_pretrained(self, pretrained_model_dir: str, **kwargs):
|
||||
self.tokenizer.save_pretrained(pretrained_model_dir, **kwargs)
|
||||
|
||||
def clean_up_tokenization(self, out_string: str) -> str:
|
||||
return self.tokenizer.clean_up_tokenization(out_string)
|
||||
|
||||
@property
|
||||
def clean_up_tokenization_spaces(self):
|
||||
return self.tokenizer.clean_up_tokenization_spaces
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return self.tokenizer.is_fast
|
||||
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
# Assumed to be O(1) complexity
|
||||
return self.tokenizer.get_added_vocab()
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False) -> Union[str, List[str]]:
|
||||
return self.tokenizer.convert_ids_to_tokens(
|
||||
ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def convert_tokens_to_string(
|
||||
self,
|
||||
tokens: List[str],
|
||||
skip_special_tokens: bool = False,
|
||||
spaces_between_special_tokens: bool = True) -> str:
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L172
|
||||
if self.is_fast or not self.get_added_vocab():
|
||||
return self.tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
sub_texts: List[str] = []
|
||||
current_sub_text: List[str] = []
|
||||
for token in tokens:
|
||||
if skip_special_tokens and token in self._all_special_tokens_set:
|
||||
continue
|
||||
if token in self.get_added_vocab():
|
||||
if current_sub_text:
|
||||
sub_text = self.tokenizer.convert_tokens_to_string(
|
||||
current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
current_sub_text = []
|
||||
sub_texts.append(token)
|
||||
else:
|
||||
current_sub_text.append(token)
|
||||
if current_sub_text:
|
||||
sub_text = self.tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
if spaces_between_special_tokens:
|
||||
return " ".join(sub_texts)
|
||||
else:
|
||||
return "".join(sub_texts)
|
||||
|
||||
@nvtx_range_debug("decode_incrementally")
|
||||
def decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
flush: bool = False,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
stream_interval: int = 1) -> Tuple[str, dict]:
|
||||
"""Incremental detokenization, typically used for streaming generation.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The incremental token ids.
|
||||
prev_text (str): The previous decoded text. None if it's the first iteration.
|
||||
states (dict): A dict that saves previous states for incremental detokenization. None if it's the first iteration.
|
||||
flush (bool): Force flushing the pending tokens to decoded text.
|
||||
skip_special_tokens (bool): Whether to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (bool): Whether to clean up tokenization spaces.
|
||||
spaces_between_special_tokens (bool): Whether to add spaces between special tokens.
|
||||
stream_interval (int): The iteration interval to create responses under the streaming mode.
|
||||
|
||||
Returns:
|
||||
text, states (Tuple[str, dict]): text is the current decoded text, states is the current incremental detokenization states.
|
||||
They should be passed to next incremental detokenization iteration, if any.
|
||||
"""
|
||||
# HF incremental detokenization implementation is faster than TRTLLM when stream_interval is smaller.
|
||||
if (TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == "TRTLLM"
|
||||
or stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD
|
||||
or spaces_between_special_tokens is False
|
||||
or not hasattr(self.tokenizer, "_tokenizer")):
|
||||
return self.trtllm_decode_incrementally(
|
||||
token_ids,
|
||||
prev_text,
|
||||
states,
|
||||
flush=flush,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
else:
|
||||
return self.hf_decode_incrementally(
|
||||
token_ids,
|
||||
prev_text,
|
||||
states,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces)
|
||||
|
||||
def trtllm_decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
flush: bool = False,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None,
|
||||
spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L238
|
||||
if prev_text is None:
|
||||
prev_text = ""
|
||||
|
||||
if states is None:
|
||||
states = {}
|
||||
last_new_tokens = states.pop('last_new_tokens', [])
|
||||
pending_tokens = states.pop('pending_tokens', [])
|
||||
|
||||
if len(last_new_tokens) > 0:
|
||||
last_new_text = self.convert_tokens_to_string(
|
||||
last_new_tokens,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
else:
|
||||
last_new_text = ""
|
||||
|
||||
new_tokens = self.convert_ids_to_tokens(
|
||||
token_ids, skip_special_tokens=skip_special_tokens)
|
||||
pending_tokens.extend(new_tokens)
|
||||
|
||||
curr_new_text = self.convert_tokens_to_string(
|
||||
last_new_tokens + pending_tokens,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
if not flush and (len(curr_new_text.rstrip()) <= len(
|
||||
last_new_text.rstrip()) or curr_new_text.endswith("<EFBFBD>")):
|
||||
return prev_text, {
|
||||
'last_new_tokens': last_new_tokens,
|
||||
'pending_tokens': pending_tokens
|
||||
}
|
||||
|
||||
# Remove the part of last_new_text
|
||||
curr_new_text = curr_new_text[len(last_new_text):]
|
||||
if clean_up_tokenization_spaces is None:
|
||||
clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
|
||||
if clean_up_tokenization_spaces:
|
||||
curr_new_text = self.clean_up_tokenization(curr_new_text)
|
||||
return prev_text + curr_new_text, {'last_new_tokens': pending_tokens}
|
||||
|
||||
def hf_decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None
|
||||
) -> Tuple[str, dict]:
|
||||
if states is None:
|
||||
states = {
|
||||
'decode_stream':
|
||||
DecodeStream(skip_special_tokens=skip_special_tokens)
|
||||
}
|
||||
|
||||
decode_stream = states.get('decode_stream')
|
||||
results = [
|
||||
result for tid in token_ids
|
||||
if (result := decode_stream.step(self.tokenizer._tokenizer, tid)
|
||||
) is not None
|
||||
__all__ = [
|
||||
"TLLM_INCREMENTAL_DETOKENIZATION_BACKEND",
|
||||
"TLLM_STREAM_INTERVAL_THRESHOLD",
|
||||
"TokenizerBase",
|
||||
"TransformersTokenizer",
|
||||
"DeepseekV32Tokenizer",
|
||||
"tokenizer_factory",
|
||||
"_xgrammar_tokenizer_info",
|
||||
"_llguidance_tokenizer_info",
|
||||
"load_hf_tokenizer",
|
||||
]
|
||||
curr_new_text = "".join(results)
|
||||
if clean_up_tokenization_spaces is None:
|
||||
clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
|
||||
if clean_up_tokenization_spaces:
|
||||
curr_new_text = self.clean_up_tokenization(curr_new_text)
|
||||
|
||||
if prev_text is None:
|
||||
return curr_new_text, states
|
||||
else:
|
||||
return prev_text + curr_new_text, states
|
||||
|
||||
|
||||
def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
|
||||
TokenizerBase]] = None,
|
||||
**kwargs) -> Optional[TokenizerBase]:
|
||||
if obj is None:
|
||||
return None
|
||||
elif isinstance(obj, (str, Path)):
|
||||
default_kwargs = {
|
||||
'legacy': False,
|
||||
'padding_side': 'left',
|
||||
'truncation_side': 'left',
|
||||
'trust_remote_code': True,
|
||||
'use_fast': True,
|
||||
}
|
||||
default_kwargs.update(kwargs)
|
||||
return TransformersTokenizer.from_pretrained(obj, **default_kwargs)
|
||||
elif isinstance(obj, TokenizerBase):
|
||||
return obj
|
||||
elif isinstance(obj, PreTrainedTokenizerBase):
|
||||
return TransformersTokenizer(obj)
|
||||
else:
|
||||
raise TypeError(f"Unrecognized tokenizer {obj}")
|
||||
|
||||
|
||||
def _xgrammar_tokenizer_info(tokenizer):
|
||||
# Reference: https://github.com/mlc-ai/xgrammar/blob/b9a16de54e1e0eff58da14c65750414cceaf1a6f/python/xgrammar/tokenizer_info.py#L133
|
||||
if isinstance(tokenizer, TokenizerBase):
|
||||
tokenizer = tokenizer.tokenizer
|
||||
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
|
||||
try:
|
||||
encoded_vocab = tokenizer.get_vocab()
|
||||
encoded_vocab = [
|
||||
token
|
||||
for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
|
||||
]
|
||||
except AttributeError as e:
|
||||
msg = (
|
||||
f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer "
|
||||
"should have a get_vocab method.")
|
||||
raise ValueError(msg) from e
|
||||
|
||||
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
backend_str = tokenizer.backend_tokenizer.to_str()
|
||||
return {
|
||||
"encoded_vocab": encoded_vocab,
|
||||
"tokenizer_str": backend_str,
|
||||
"stop_token_ids": stop_token_ids
|
||||
}
|
||||
elif ("vocab_file" in tokenizer.vocab_files_names
|
||||
and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]):
|
||||
return {
|
||||
"encoded_vocab": encoded_vocab,
|
||||
"stop_token_ids": stop_token_ids
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")
|
||||
|
||||
|
||||
def _llguidance_tokenizer_info(tokenizer):
|
||||
tokenizer_info = _xgrammar_tokenizer_info(tokenizer)
|
||||
if tokenizer_info.get("tokenizer_str") is None:
|
||||
raise ValueError("missing tokenizer_str")
|
||||
return tokenizer_info
|
||||
|
||||
|
||||
def load_hf_tokenizer(model_dir: str,
|
||||
trust_remote_code: bool = True,
|
||||
use_fast: bool = True,
|
||||
**kwargs) -> Optional[TransformersTokenizer]:
|
||||
''' Load a tokenizer from a Hugging Face model directory.
|
||||
|
||||
Args:
|
||||
model_dir (str): The model directory.
|
||||
trust_remote_code (bool): Whether to trust the remote code.
|
||||
use_fast (bool): Whether to use the fast tokenizer.
|
||||
|
||||
Returns:
|
||||
A TransformersTokenizer object if the tokenizer is loaded successfully.
|
||||
'''
|
||||
|
||||
try:
|
||||
return TransformersTokenizer.from_pretrained(
|
||||
model_dir,
|
||||
legacy=False,
|
||||
padding_side='left',
|
||||
truncation_side='left',
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=use_fast,
|
||||
**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load hf tokenizer from {model_dir}, encounter error: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
21
tensorrt_llm/tokenizer/__init__.py
Normal file
21
tensorrt_llm/tokenizer/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
from .tokenizer import (
|
||||
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND,
|
||||
TLLM_STREAM_INTERVAL_THRESHOLD,
|
||||
TokenizerBase,
|
||||
TransformersTokenizer,
|
||||
_llguidance_tokenizer_info,
|
||||
_xgrammar_tokenizer_info,
|
||||
load_hf_tokenizer,
|
||||
tokenizer_factory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TLLM_INCREMENTAL_DETOKENIZATION_BACKEND",
|
||||
"TLLM_STREAM_INTERVAL_THRESHOLD",
|
||||
"TokenizerBase",
|
||||
"TransformersTokenizer",
|
||||
"tokenizer_factory",
|
||||
"_xgrammar_tokenizer_info",
|
||||
"_llguidance_tokenizer_info",
|
||||
"load_hf_tokenizer",
|
||||
]
|
||||
14
tensorrt_llm/tokenizer/deepseek_v32/__init__.py
Normal file
14
tensorrt_llm/tokenizer/deepseek_v32/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""DeepSeek V3.2 tokenizer and encoding utilities.
|
||||
|
||||
This is a temporary workaround for DeepSeek-V3.2 model as HF does not support it yet.
|
||||
TODO: Remove this once HF supports DeepSeek-V3.2
|
||||
"""
|
||||
|
||||
from .encoding import encode_messages, parse_message_from_completion_text
|
||||
from .tokenizer import DeepseekV32Tokenizer
|
||||
|
||||
__all__ = [
|
||||
"DeepseekV32Tokenizer",
|
||||
"encode_messages",
|
||||
"parse_message_from_completion_text",
|
||||
]
|
||||
425
tensorrt_llm/tokenizer/deepseek_v32/encoding.py
Normal file
425
tensorrt_llm/tokenizer/deepseek_v32/encoding.py
Normal file
@ -0,0 +1,425 @@
|
||||
# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py
|
||||
# ruff: noqa: E501
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
TOOLS_SYSTEM_TEMPLATE = """## Tools
|
||||
You have access to a set of tools you can use to answer the user's question.
|
||||
You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user:
|
||||
<{dsml_token}function_calls>
|
||||
<{dsml_token}invoke name="$FUNCTION_NAME">
|
||||
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
<{dsml_token}invoke name="$FUNCTION_NAME2">
|
||||
...
|
||||
</{dsml_token}invoke>
|
||||
</{dsml_token}function_calls>
|
||||
String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).
|
||||
If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:
|
||||
<{dsml_token}function_calls>
|
||||
...
|
||||
</{dsml_token}function_calls>
|
||||
<function_results>
|
||||
...
|
||||
</function_results>
|
||||
{thinking_start_token}...thinking about results{thinking_end_token}
|
||||
Here are the functions available in JSONSchema format:
|
||||
<functions>
|
||||
{tool_schemas}
|
||||
</functions>
|
||||
"""
|
||||
|
||||
bos_token: str = "<|begin▁of▁sentence|>"
|
||||
eos_token: str = "<|end▁of▁sentence|>"
|
||||
thinking_start_token: str = "<think>"
|
||||
thinking_end_token: str = "</think>"
|
||||
dsml_token: str = "|DSML|"
|
||||
system_msg_template: str = "{content}"
|
||||
user_msg_template: str = "<|User|>{content}<|Assistant|>"
|
||||
assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>"
|
||||
thinking_template = "{reasoning_content}"
|
||||
|
||||
response_format_template: str = (
|
||||
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
|
||||
)
|
||||
tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n</{dsml_token}invoke>'
|
||||
tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>"
|
||||
|
||||
tool_output_template: str = "\n<result>{content}</result>"
|
||||
|
||||
|
||||
def to_json(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
except Exception:
|
||||
return json.dumps(value, ensure_ascii=True)
|
||||
|
||||
|
||||
def tools_from_openai_format(tools):
|
||||
return [tool["function"] for tool in tools]
|
||||
|
||||
|
||||
def tool_calls_from_openai_format(tool_calls):
|
||||
return [
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": tool_call["function"]["arguments"],
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def tool_calls_to_openai_format(tool_calls):
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": tool_call["arguments"],
|
||||
},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
|
||||
p_dsml_template = (
|
||||
"""<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>"""
|
||||
)
|
||||
P_dsml_strs = []
|
||||
|
||||
arguments = json.loads(tool_call["arguments"])
|
||||
|
||||
for k, v in arguments.items():
|
||||
p_dsml_str = p_dsml_template.format(
|
||||
dsml_token=dsml_token,
|
||||
key=k,
|
||||
is_str="true" if isinstance(v, str) else "false",
|
||||
value=v if isinstance(v, str) else to_json(v),
|
||||
)
|
||||
|
||||
P_dsml_strs.append(p_dsml_str)
|
||||
|
||||
return "\n".join(P_dsml_strs)
|
||||
|
||||
|
||||
def decode_dsml_to_arguments(
|
||||
tool_name: str, tool_args: Dict[str, Tuple[str, str]]
|
||||
) -> Dict[str, str]:
|
||||
def _decode_value(key: str, value: str, string: str):
|
||||
if string == "true":
|
||||
value = to_json(value)
|
||||
return f"{to_json(key)}: {value}"
|
||||
|
||||
tool_args_json = (
|
||||
"{"
|
||||
+ ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()])
|
||||
+ "}"
|
||||
)
|
||||
return dict(name=tool_name, arguments=tool_args_json)
|
||||
|
||||
|
||||
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
|
||||
tools_json = [to_json(t) for t in tools]
|
||||
|
||||
return TOOLS_SYSTEM_TEMPLATE.format(
|
||||
tool_schemas="\n".join(tools_json),
|
||||
dsml_token=dsml_token,
|
||||
thinking_start_token=thinking_start_token,
|
||||
thinking_end_token=thinking_end_token,
|
||||
)
|
||||
|
||||
|
||||
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
|
||||
last_user_index = -1
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
if messages[idx].get("role") in ["user", "developer"]:
|
||||
last_user_index = idx
|
||||
break
|
||||
return last_user_index
|
||||
|
||||
|
||||
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str) -> str:
|
||||
assert 0 <= index < len(messages)
|
||||
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
|
||||
|
||||
prompt = ""
|
||||
msg = messages[index]
|
||||
last_user_idx = find_last_user_index(messages)
|
||||
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
tools = msg.get("tools")
|
||||
response_format = msg.get("response_format")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
# support both reasoning_content and reasoning for compatibility
|
||||
reasoning_content = msg.get("reasoning") or msg.get("reasoning_content")
|
||||
|
||||
if tools:
|
||||
tools = tools_from_openai_format(tools)
|
||||
if tool_calls:
|
||||
tool_calls = tool_calls_from_openai_format(tool_calls)
|
||||
|
||||
if role == "system":
|
||||
prompt += system_msg_template.format(content=content or "")
|
||||
if tools:
|
||||
prompt += "\n\n" + render_tools(tools)
|
||||
|
||||
if response_format:
|
||||
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
|
||||
|
||||
elif role == "developer":
|
||||
assert content, f"Invalid message for role `{role}`: {msg}"
|
||||
content_developer = ""
|
||||
if tools:
|
||||
content_developer += "\n\n" + render_tools(tools)
|
||||
|
||||
if response_format:
|
||||
content_developer += "\n\n" + response_format_template.format(
|
||||
schema=to_json(response_format)
|
||||
)
|
||||
|
||||
content_developer += "\n\n# The user's message is: {}".format(content)
|
||||
|
||||
prompt += user_msg_template.format(content=content_developer)
|
||||
if index == last_user_idx and thinking_mode == "thinking":
|
||||
prompt += thinking_start_token
|
||||
else:
|
||||
prompt += thinking_end_token
|
||||
|
||||
elif role == "user":
|
||||
prompt += user_msg_template.format(content=content)
|
||||
|
||||
if index == last_user_idx and thinking_mode == "thinking":
|
||||
prompt += thinking_start_token
|
||||
else:
|
||||
prompt += thinking_end_token
|
||||
|
||||
elif role == "tool":
|
||||
prev_assistant_idx = index - 1
|
||||
assistant_msg = messages[prev_assistant_idx]
|
||||
while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool":
|
||||
prev_assistant_idx -= 1
|
||||
assistant_msg = messages[prev_assistant_idx]
|
||||
|
||||
assert index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant", (
|
||||
f"Invalid messages at {index}:\n{assistant_msg}"
|
||||
)
|
||||
|
||||
tool_call_order = index - prev_assistant_idx
|
||||
assistant_tool_calls = assistant_msg.get("tool_calls")
|
||||
assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, (
|
||||
"No tool calls but found tool output"
|
||||
)
|
||||
|
||||
if tool_call_order == 1:
|
||||
prompt += "\n\n<function_results>"
|
||||
|
||||
prompt += tool_output_template.format(content=content)
|
||||
|
||||
if tool_call_order == len(assistant_tool_calls):
|
||||
prompt += "\n</function_results>"
|
||||
|
||||
if index >= last_user_idx and thinking_mode == "thinking":
|
||||
prompt += "\n\n" + thinking_start_token
|
||||
else:
|
||||
prompt += "\n\n" + thinking_end_token
|
||||
|
||||
elif role == "assistant":
|
||||
prev_assistant_idx = index
|
||||
thinking_part = ""
|
||||
|
||||
tool_calls_content = ""
|
||||
if tool_calls:
|
||||
tool_calls = [
|
||||
tool_call_template.format(
|
||||
dsml_token=dsml_token,
|
||||
name=tool_call.get("name"),
|
||||
arguments=encode_arguments_to_dsml(tool_call),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
tool_calls_content += "\n\n" + tool_calls_template.format(
|
||||
dsml_token=dsml_token, tool_calls="\n".join(tool_calls)
|
||||
)
|
||||
|
||||
summary_content = content or ""
|
||||
|
||||
if thinking_mode == "thinking" and index > last_user_idx:
|
||||
assert reasoning_content or tool_calls, (
|
||||
f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
|
||||
)
|
||||
thinking_part = (
|
||||
thinking_template.format(reasoning_content=reasoning_content or "")
|
||||
+ thinking_end_token
|
||||
)
|
||||
|
||||
prompt += assistant_msg_template.format(
|
||||
reasoning=thinking_part,
|
||||
content=summary_content,
|
||||
tool_calls=tool_calls_content,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown role: {role}")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def drop_thinking_messages(
|
||||
messages: List[Dict[str, Any]], last_user_idx: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
messages_wo_thinking: List[Dict[str, Any]] = []
|
||||
last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role in ["user", "system", "tool"] or idx >= last_user_idx:
|
||||
messages_wo_thinking.append(msg)
|
||||
continue
|
||||
|
||||
elif role == "assistant":
|
||||
msg_wo_thinking = copy.copy(msg)
|
||||
msg_wo_thinking.pop("reasoning_content", None)
|
||||
msg_wo_thinking.pop("reasoning", None)
|
||||
messages_wo_thinking.append(msg_wo_thinking)
|
||||
|
||||
return messages_wo_thinking
|
||||
|
||||
|
||||
def encode_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
thinking_mode: str,
|
||||
context: Optional[List[Dict[str, Any]]] = None,
|
||||
drop_thinking: bool = True,
|
||||
add_default_bos_token: bool = True,
|
||||
) -> str:
|
||||
context = context if context else []
|
||||
full_messages = context + messages
|
||||
|
||||
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
|
||||
|
||||
if thinking_mode == "thinking" and drop_thinking:
|
||||
full_messages = drop_thinking_messages(full_messages)
|
||||
|
||||
for idx in range(len(messages)):
|
||||
prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
|
||||
min_pos = len(text)
|
||||
matched_stop = None
|
||||
|
||||
for s in stop:
|
||||
pos = text.find(s, index)
|
||||
if pos != -1 and pos < min_pos:
|
||||
min_pos = pos
|
||||
matched_stop = s
|
||||
|
||||
if matched_stop:
|
||||
content = text[index:min_pos]
|
||||
return min_pos + len(matched_stop), content, matched_stop
|
||||
else:
|
||||
content = text[index:]
|
||||
return len(text), content, None
|
||||
|
||||
|
||||
def parse_tool_calls(index: int, text: str):
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
stop_token = None
|
||||
tool_calls_end_token = f"</{dsml_token}function_calls>"
|
||||
|
||||
while index < len(text):
|
||||
index, _, stop_token = _read_until_stop(
|
||||
index, text, [f"<{dsml_token}invoke", tool_calls_end_token]
|
||||
)
|
||||
assert _ == ">\n", "Tool call format error"
|
||||
|
||||
if stop_token == tool_calls_end_token:
|
||||
break
|
||||
|
||||
assert stop_token is not None, "Missing special token"
|
||||
|
||||
index, tool_name_content, stop_token = _read_until_stop(
|
||||
index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"]
|
||||
)
|
||||
|
||||
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
|
||||
assert len(p_tool_name) == 1, "Tool name format error"
|
||||
tool_name = p_tool_name[0]
|
||||
|
||||
tool_args: Dict[str, Tuple[str, str]] = {}
|
||||
while stop_token == f"<{dsml_token}parameter":
|
||||
index, param_content, stop_token = _read_until_stop(
|
||||
index, text, [f"/{dsml_token}parameter"]
|
||||
)
|
||||
|
||||
param_kv = re.findall(
|
||||
r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL
|
||||
)
|
||||
assert len(param_kv) == 1, "Parameter format error"
|
||||
param_name, string, param_value = param_kv[0]
|
||||
|
||||
assert param_name not in tool_args, "Duplicate parameter name"
|
||||
tool_args[param_name] = (param_value, string)
|
||||
|
||||
index, content, stop_token = _read_until_stop(
|
||||
index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"]
|
||||
)
|
||||
assert content == ">\n", "Parameter format error"
|
||||
|
||||
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return index, stop_token, tool_calls
|
||||
|
||||
|
||||
# NOTE: This function is designed to parse only correctly
|
||||
# formatted string and will not attempt to correct malformed output
|
||||
# that may be generated by the model.
|
||||
def parse_message_from_completion_text(text: str, thinking_mode: str):
|
||||
summary_content, reasoning_content, tool_calls = "", "", []
|
||||
index, stop_token = 0, None
|
||||
tool_calls_start_token = f"\n\n<{dsml_token}function_calls"
|
||||
|
||||
is_thinking, is_tool_calling = thinking_mode == "thinking", False
|
||||
|
||||
if is_thinking:
|
||||
index, content_delta, stop_token = _read_until_stop(
|
||||
index, text, [thinking_end_token, tool_calls_start_token]
|
||||
)
|
||||
reasoning_content = content_delta
|
||||
assert stop_token == thinking_end_token, "Invalid thinking format"
|
||||
|
||||
index, content_delta, stop_token = _read_until_stop(
|
||||
index, text, [eos_token, tool_calls_start_token]
|
||||
)
|
||||
summary_content = content_delta
|
||||
if stop_token == tool_calls_start_token:
|
||||
is_tool_calling = True
|
||||
else:
|
||||
assert stop_token == eos_token, "Invalid summary format"
|
||||
|
||||
if is_tool_calling:
|
||||
index, stop_token, tool_calls = parse_tool_calls(index, text)
|
||||
|
||||
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
|
||||
assert not tool_ends_text, "Unexpected content after tool calls"
|
||||
|
||||
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
|
||||
|
||||
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
|
||||
assert sp_token not in summary_content and sp_token not in reasoning_content, (
|
||||
"Unexpected special token in content"
|
||||
)
|
||||
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": summary_content,
|
||||
"reasoning_content": reasoning_content,
|
||||
"reasoning": reasoning_content,
|
||||
"tool_calls": tool_calls_to_openai_format(tool_calls),
|
||||
}
|
||||
147
tensorrt_llm/tokenizer/deepseek_v32/tokenizer.py
Normal file
147
tensorrt_llm/tokenizer/deepseek_v32/tokenizer.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""DeepSeek V3.2 tokenizer implementation.
|
||||
|
||||
This is a temporary workaround for DeepSeek-V3.2 model as HF does not support it yet.
|
||||
TODO: Remove this once HF supports DeepSeek-V3.2
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ..tokenizer import TransformersTokenizer
|
||||
from .encoding import encode_messages
|
||||
|
||||
|
||||
class DeepseekV32Tokenizer(TransformersTokenizer):
|
||||
"""DeepSeek V3.2 tokenizer with custom chat template."""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
# tokenizer should be the HF tokenizer
|
||||
self.tokenizer = tokenizer
|
||||
self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
path_or_repo_id: str | Path,
|
||||
*args,
|
||||
trust_remote_code: bool = False,
|
||||
revision: str | None = None,
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> "DeepseekV32Tokenizer":
|
||||
# Load HF tokenizer
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
path_or_repo_id,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
return DeepseekV32Tokenizer(hf_tokenizer)
|
||||
|
||||
def apply_chat_template(self, messages, tools=None, **kwargs):
|
||||
thinking = kwargs.get("thinking", False)
|
||||
thinking_mode = "thinking" if thinking else "chat"
|
||||
messages = messages.copy()
|
||||
drop_thinking = True
|
||||
if tools is not None and len(tools) > 0:
|
||||
messages.insert(0, {"role": "system"})
|
||||
messages[0]["tools"] = tools
|
||||
drop_thinking = False
|
||||
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
|
||||
prompt_str = encode_messages(messages, **encode_config) # type: ignore
|
||||
return prompt_str
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return self.tokenizer.all_special_tokens
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return self.tokenizer.all_special_ids
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self.tokenizer.bos_token_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return self.tokenizer.is_fast
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.tokenizer.vocab_size
|
||||
|
||||
@property
|
||||
def max_token_id(self) -> int:
|
||||
return self.tokenizer.max_token_id
|
||||
|
||||
@property
|
||||
def truncation_side(self) -> str:
|
||||
return self.tokenizer.truncation_side
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(id(self))
|
||||
|
||||
def __len__(self) -> int:
|
||||
# </think> is an added token in DeepseekV32 tokenizer
|
||||
return self.vocab_size + len(self.get_added_vocab())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str],
|
||||
text_pair: str | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
truncation: bool = False,
|
||||
max_length: int | None = None,
|
||||
) -> Any:
|
||||
return self.tokenizer(
|
||||
text,
|
||||
text_pair=text_pair,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
return self.tokenizer.get_added_vocab()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> list[int]:
|
||||
return self.tokenizer.encode(
|
||||
text,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
return self.tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False, **kwargs) -> str:
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens, **kwargs)
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: list[int],
|
||||
skip_special_tokens: bool = False,
|
||||
) -> list[str]:
|
||||
return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
365
tensorrt_llm/tokenizer/tokenizer.py
Normal file
365
tensorrt_llm/tokenizer/tokenizer.py
Normal file
@ -0,0 +1,365 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from .._utils import nvtx_range_debug
|
||||
from ..logger import logger
|
||||
|
||||
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = os.environ.get(
|
||||
"TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
|
||||
TLLM_STREAM_INTERVAL_THRESHOLD = int(
|
||||
os.environ.get("TLLM_STREAM_INTERVAL_THRESHOLD", "24"))
|
||||
try:
|
||||
from tokenizers.decoders import DecodeStream # noqa
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
f"HF incremental detokenization is unsupported by tokenizer<0.21.0; fallback to TRTLLM incremental detokenization."
|
||||
)
|
||||
TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = "TRTLLM"
|
||||
|
||||
|
||||
class TokenizerBase(PreTrainedTokenizerBase):
|
||||
''' This is a protocol for the tokenizer. Users can implement their own tokenizer by inheriting this class. '''
|
||||
|
||||
|
||||
class TransformersTokenizer(TokenizerBase):
|
||||
''' A wrapper for the Transformers' tokenizer.
|
||||
This is the default tokenizer for LLM. '''
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)
|
||||
|
||||
def __call__(self, text: str, *args, **kwargs) -> Any:
|
||||
return self.tokenizer(text, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self.tokenizer.name_or_path
|
||||
|
||||
def encode(self, text: str, *args, **kwargs) -> List[int]:
|
||||
return self.tokenizer.encode(text, *args, **kwargs)
|
||||
|
||||
def decode(self, token_ids: List[int], *args, **kwargs) -> str:
|
||||
return self.tokenizer.decode(token_ids, *args, **kwargs)
|
||||
|
||||
def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
|
||||
return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)
|
||||
|
||||
def get_chat_template(self,
|
||||
chat_template: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None) -> str:
|
||||
return self.tokenizer.get_chat_template(chat_template, tools)
|
||||
|
||||
def apply_chat_template(
|
||||
self, conversation: Union[List[Dict[str, str]],
|
||||
List[List[Dict[str, str]]]], *args,
|
||||
**kwargs) -> Union[str, List[int], List[str], List[List[int]]]:
|
||||
return self.tokenizer.apply_chat_template(conversation, *args, **kwargs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.tokenizer})"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir,
|
||||
**kwargs)
|
||||
return cls(tokenizer)
|
||||
|
||||
def save_pretrained(self, pretrained_model_dir: str, **kwargs):
|
||||
self.tokenizer.save_pretrained(pretrained_model_dir, **kwargs)
|
||||
|
||||
def clean_up_tokenization(self, out_string: str) -> str:
|
||||
return self.tokenizer.clean_up_tokenization(out_string)
|
||||
|
||||
@property
|
||||
def clean_up_tokenization_spaces(self):
|
||||
return self.tokenizer.clean_up_tokenization_spaces
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return self.tokenizer.is_fast
|
||||
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
# Assumed to be O(1) complexity
|
||||
return self.tokenizer.get_added_vocab()
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False) -> Union[str, List[str]]:
|
||||
return self.tokenizer.convert_ids_to_tokens(
|
||||
ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def convert_tokens_to_string(
|
||||
self,
|
||||
tokens: List[str],
|
||||
skip_special_tokens: bool = False,
|
||||
spaces_between_special_tokens: bool = True) -> str:
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L172
|
||||
if self.is_fast or not self.get_added_vocab():
|
||||
return self.tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
sub_texts: List[str] = []
|
||||
current_sub_text: List[str] = []
|
||||
for token in tokens:
|
||||
if skip_special_tokens and token in self._all_special_tokens_set:
|
||||
continue
|
||||
if token in self.get_added_vocab():
|
||||
if current_sub_text:
|
||||
sub_text = self.tokenizer.convert_tokens_to_string(
|
||||
current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
current_sub_text = []
|
||||
sub_texts.append(token)
|
||||
else:
|
||||
current_sub_text.append(token)
|
||||
if current_sub_text:
|
||||
sub_text = self.tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
if spaces_between_special_tokens:
|
||||
return " ".join(sub_texts)
|
||||
else:
|
||||
return "".join(sub_texts)
|
||||
|
||||
@nvtx_range_debug("decode_incrementally")
|
||||
def decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
flush: bool = False,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
stream_interval: int = 1) -> Tuple[str, dict]:
|
||||
"""Incremental detokenization, typically used for streaming generation.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The incremental token ids.
|
||||
prev_text (str): The previous decoded text. None if it's the first iteration.
|
||||
states (dict): A dict that saves previous states for incremental detokenization. None if it's the first iteration.
|
||||
flush (bool): Force flushing the pending tokens to decoded text.
|
||||
skip_special_tokens (bool): Whether to remove special tokens in the decoding.
|
||||
clean_up_tokenization_spaces (bool): Whether to clean up tokenization spaces.
|
||||
spaces_between_special_tokens (bool): Whether to add spaces between special tokens.
|
||||
stream_interval (int): The iteration interval to create responses under the streaming mode.
|
||||
|
||||
Returns:
|
||||
text, states (Tuple[str, dict]): text is the current decoded text, states is the current incremental detokenization states.
|
||||
They should be passed to next incremental detokenization iteration, if any.
|
||||
"""
|
||||
# HF incremental detokenization implementation is faster than TRTLLM when stream_interval is smaller.
|
||||
if (TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == "TRTLLM"
|
||||
or stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD
|
||||
or spaces_between_special_tokens is False
|
||||
or not hasattr(self.tokenizer, "_tokenizer")):
|
||||
return self.trtllm_decode_incrementally(
|
||||
token_ids,
|
||||
prev_text,
|
||||
states,
|
||||
flush=flush,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
else:
|
||||
return self.hf_decode_incrementally(
|
||||
token_ids,
|
||||
prev_text,
|
||||
states,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces)
|
||||
|
||||
def trtllm_decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
flush: bool = False,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None,
|
||||
spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L238
|
||||
if prev_text is None:
|
||||
prev_text = ""
|
||||
|
||||
if states is None:
|
||||
states = {}
|
||||
last_new_tokens = states.pop('last_new_tokens', [])
|
||||
pending_tokens = states.pop('pending_tokens', [])
|
||||
|
||||
if len(last_new_tokens) > 0:
|
||||
last_new_text = self.convert_tokens_to_string(
|
||||
last_new_tokens,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
else:
|
||||
last_new_text = ""
|
||||
|
||||
new_tokens = self.convert_ids_to_tokens(
|
||||
token_ids, skip_special_tokens=skip_special_tokens)
|
||||
pending_tokens.extend(new_tokens)
|
||||
|
||||
curr_new_text = self.convert_tokens_to_string(
|
||||
last_new_tokens + pending_tokens,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens)
|
||||
if not flush and (len(curr_new_text.rstrip()) <= len(
|
||||
last_new_text.rstrip()) or curr_new_text.endswith("<EFBFBD>")):
|
||||
return prev_text, {
|
||||
'last_new_tokens': last_new_tokens,
|
||||
'pending_tokens': pending_tokens
|
||||
}
|
||||
|
||||
# Remove the part of last_new_text
|
||||
curr_new_text = curr_new_text[len(last_new_text):]
|
||||
if clean_up_tokenization_spaces is None:
|
||||
clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
|
||||
if clean_up_tokenization_spaces:
|
||||
curr_new_text = self.clean_up_tokenization(curr_new_text)
|
||||
return prev_text + curr_new_text, {'last_new_tokens': pending_tokens}
|
||||
|
||||
def hf_decode_incrementally(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
prev_text: Optional[str] = None,
|
||||
states: Optional[dict] = None,
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = None
|
||||
) -> Tuple[str, dict]:
|
||||
if states is None:
|
||||
states = {
|
||||
'decode_stream':
|
||||
DecodeStream(skip_special_tokens=skip_special_tokens)
|
||||
}
|
||||
|
||||
decode_stream = states.get('decode_stream')
|
||||
results = [
|
||||
result for tid in token_ids
|
||||
if (result := decode_stream.step(self.tokenizer._tokenizer, tid)
|
||||
) is not None
|
||||
]
|
||||
curr_new_text = "".join(results)
|
||||
if clean_up_tokenization_spaces is None:
|
||||
clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
|
||||
if clean_up_tokenization_spaces:
|
||||
curr_new_text = self.clean_up_tokenization(curr_new_text)
|
||||
|
||||
if prev_text is None:
|
||||
return curr_new_text, states
|
||||
else:
|
||||
return prev_text + curr_new_text, states
|
||||
|
||||
|
||||
def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
|
||||
TokenizerBase]] = None,
|
||||
**kwargs) -> Optional[TokenizerBase]:
|
||||
if obj is None:
|
||||
return None
|
||||
elif isinstance(obj, (str, Path)):
|
||||
default_kwargs = {
|
||||
'legacy': False,
|
||||
'padding_side': 'left',
|
||||
'truncation_side': 'left',
|
||||
'trust_remote_code': True,
|
||||
'use_fast': True,
|
||||
}
|
||||
default_kwargs.update(kwargs)
|
||||
return TransformersTokenizer.from_pretrained(obj, **default_kwargs)
|
||||
elif isinstance(obj, TokenizerBase):
|
||||
return obj
|
||||
elif isinstance(obj, PreTrainedTokenizerBase):
|
||||
return TransformersTokenizer(obj)
|
||||
else:
|
||||
raise TypeError(f"Unrecognized tokenizer {obj}")
|
||||
|
||||
|
||||
def _xgrammar_tokenizer_info(tokenizer):
|
||||
# Reference: https://github.com/mlc-ai/xgrammar/blob/b9a16de54e1e0eff58da14c65750414cceaf1a6f/python/xgrammar/tokenizer_info.py#L133
|
||||
if isinstance(tokenizer, TokenizerBase):
|
||||
tokenizer = tokenizer.tokenizer
|
||||
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
|
||||
try:
|
||||
encoded_vocab = tokenizer.get_vocab()
|
||||
encoded_vocab = [
|
||||
token
|
||||
for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
|
||||
]
|
||||
except AttributeError as e:
|
||||
msg = (
|
||||
f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer "
|
||||
"should have a get_vocab method.")
|
||||
raise ValueError(msg) from e
|
||||
|
||||
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
backend_str = tokenizer.backend_tokenizer.to_str()
|
||||
return {
|
||||
"encoded_vocab": encoded_vocab,
|
||||
"tokenizer_str": backend_str,
|
||||
"stop_token_ids": stop_token_ids
|
||||
}
|
||||
elif ("vocab_file" in tokenizer.vocab_files_names
|
||||
and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]):
|
||||
return {
|
||||
"encoded_vocab": encoded_vocab,
|
||||
"stop_token_ids": stop_token_ids
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")
|
||||
|
||||
|
||||
def _llguidance_tokenizer_info(tokenizer):
|
||||
tokenizer_info = _xgrammar_tokenizer_info(tokenizer)
|
||||
if tokenizer_info.get("tokenizer_str") is None:
|
||||
raise ValueError("missing tokenizer_str")
|
||||
return tokenizer_info
|
||||
|
||||
|
||||
def load_hf_tokenizer(model_dir: str,
|
||||
trust_remote_code: bool = True,
|
||||
use_fast: bool = True,
|
||||
**kwargs) -> Optional[TransformersTokenizer]:
|
||||
''' Load a tokenizer from a Hugging Face model directory.
|
||||
|
||||
Args:
|
||||
model_dir (str): The model directory.
|
||||
trust_remote_code (bool): Whether to trust the remote code.
|
||||
use_fast (bool): Whether to use the fast tokenizer.
|
||||
|
||||
Returns:
|
||||
A TransformersTokenizer object if the tokenizer is loaded successfully.
|
||||
'''
|
||||
|
||||
try:
|
||||
return TransformersTokenizer.from_pretrained(
|
||||
model_dir,
|
||||
legacy=False,
|
||||
padding_side='left',
|
||||
truncation_side='left',
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=use_fast,
|
||||
**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load hf tokenizer from {model_dir}, encounter error: {e}"
|
||||
)
|
||||
return None
|
||||
@ -60,6 +60,10 @@ methods:
|
||||
annotation: Optional[str]
|
||||
default: null
|
||||
status: prototype
|
||||
custom_tokenizer:
|
||||
annotation: Optional[str]
|
||||
default: null
|
||||
status: prototype
|
||||
# reasoning
|
||||
reasoning_parser:
|
||||
annotation: Optional[str]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user