Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
feat: support more parameters in openai worker of scaffolding (#5115)
Signed-off-by: Clay <ccs96307@gmail.com>
parent 24ac9b5f69
commit 7a319524da
@@ -1,9 +1,11 @@
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
+from tensorrt_llm.serve.openai_protocol import StreamOptions
+
 
 class ScaffoldingOutput:
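For context on the new import: `StreamOptions` is the request-side model for the OpenAI `stream_options` object, used below by the new `stream_options` task field. A minimal sketch of constructing one (assuming, as in the OpenAI API, that the model exposes an `include_usage` flag; that field name is not confirmed by this diff):

from tensorrt_llm.serve.openai_protocol import StreamOptions

# Assumption: StreamOptions mirrors the OpenAI `stream_options` request
# object and accepts `include_usage`.
stream_opts = StreamOptions(include_usage=True)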
@@ -37,10 +39,28 @@ class GenerationTask(Task):
     skip_tokenizer: bool = False
     skip_detokenizer: bool = False
 
-    # sampling params
+    # sampling params for openai
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    # The special case is `num_logprobs`: its original name is `logprobs`, but that conflicts with the `logprobs` result field
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    num_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 2048
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = field(default_factory=list)
+    stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
+    suffix: Optional[str] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    user: Optional[str] = None
+
+    # sampling params
     top_k: Optional[int] = None
     return_context_logits: Optional[bool] = False
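A rough illustration of the new fields in use (the import path and the ability to construct `GenerationTask` with defaults are assumptions, not shown in this diff; the field names come from the hunk above):

from tensorrt_llm.scaffolding.task import GenerationTask  # assumed module path

# Sketch: populate the new OpenAI-style sampling fields on a task.
# `num_logprobs` maps to the OpenAI request parameter `logprobs`; it is
# renamed in the dataclass to avoid clashing with the `logprobs` result field.
task = GenerationTask()  # assumes all fields have workable defaults
task.input_str = "Write a haiku about GPUs."
task.max_tokens = 256
task.temperature = 0.7
task.top_p = 0.9
task.seed = 1234
task.stop = ["\n\n"]
task.num_logprobs = 5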
@@ -73,9 +73,25 @@ class OpenaiWorker(Worker):
             "model": self.model,
             "prompt": task.input_str,
         }
+        add_param_if_not_none(params, "best_of", [task.best_of])
+        add_param_if_not_none(params, "echo", [task.echo])
+        add_param_if_not_none(params, "frequency_penalty",
+                              [task.frequency_penalty])
+        add_param_if_not_none(params, "logit_bias", [task.logit_bias])
+        add_param_if_not_none(params, "logprobs", [task.num_logprobs])
+        add_param_if_not_none(params, "max_tokens", [task.max_tokens])
+        add_param_if_not_none(params, "n", [task.n])
+        add_param_if_not_none(params, "presence_penalty",
+                              [task.presence_penalty])
+        add_param_if_not_none(params, "seed", [task.seed])
+        add_param_if_not_none(params, "stop", [task.stop])
+        add_param_if_not_none(params, "stream", [task.stream])
+        add_param_if_not_none(params, "stream_options", [task.stream_options])
+        add_param_if_not_none(params, "suffix", [task.suffix])
+        add_param_if_not_none(params, "temperature", [task.temperature])
+        add_param_if_not_none(params, "top_p", [task.top_p])
+        add_param_if_not_none(params, "user", [task.user])
 
         return params
 
     def fill_generation_task_with_response(self, task: GenerationTask,
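The `add_param_if_not_none` helper is defined outside this hunk. Judging from the call sites above, it receives the request dict, a parameter name, and a list of candidate values, and sets the key only when a non-None candidate exists, so absent parameters fall back to the OpenAI API's own defaults. A minimal sketch under that assumption:

from typing import Any, Dict, List

def add_param_if_not_none(params: Dict[str, Any], name: str,
                          values: List[Any]) -> None:
    # Sketch only: take the first non-None candidate and add it to the
    # request params; leave the key out entirely when every candidate is
    # None, so the API applies its own default for that parameter.
    for value in values:
        if value is not None:
            params[name] = value
            return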