mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-06 03:01:50 +08:00
30 lines
785 B
Python
30 lines
785 B
Python
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from tensorrt_llm.scaffolding import GenerationTask
|
|
|
|
|
|
@dataclass
class ChatTask(GenerationTask):
    """A ``GenerationTask`` whose input is a chat transcript.

    ``messages`` holds chat message dicts; ``create_from_prompt`` appends
    entries of the form ``{"role": "user", "content": <prompt>}``.
    """

    # Chat transcript: a list of message dicts such as
    # {"role": "user", "content": "..."}.
    messages: Optional[List[Dict[str, Any]]] = None

    # NOTE(review): the attributes below carry no type annotation, so
    # @dataclass does NOT make them fields. They are plain class attributes
    # defaulting to None: not accepted by __init__ and excluded from the
    # generated __repr__/__eq__. Instances override them by assignment.
    # Annotating them would change the generated __init__/__eq__ — confirm
    # with callers before doing so.
    tools = None  # tool definitions, set by the constructors below
    # Not written anywhere in this class; presumably filled in from the
    # model response by other scaffolding code — verify against callers.
    finish_reason = None
    tool_calls = None

    @staticmethod
    def create_from_prompt(
            messages: Optional[List[Dict[str, Any]]],
            prompt: str,
            tools: Optional[List[Dict[str, Any]]] = None) -> "ChatTask":
        """Build a task by appending ``prompt`` as a user turn to ``messages``.

        Args:
            messages: Existing chat history. It is mutated in place: the new
                user message is appended to this same list, and the task
                stores a reference to it (not a copy), so later changes to
                the list are visible through ``task.messages``. ``None`` is
                treated as an empty history.
            prompt: Text of the new user message.
            tools: Tool definitions to attach to the task, if any.

        Returns:
            A new ``ChatTask`` whose ``messages`` ends with the user prompt.
        """
        task = ChatTask()
        if messages is None:
            # Start a fresh history instead of failing on ``None.append``.
            messages = []
        # Append in place (no copy) so callers holding ``messages`` see the
        # new turn and share the list with the task.
        messages.append({"role": "user", "content": prompt})
        task.messages = messages
        task.tools = tools
        return task

    @staticmethod
    def from_messages(
            messages: List[Dict[str, Any]],
            tools: Optional[List[Dict[str, Any]]] = None) -> "ChatTask":
        """Build a task from an existing chat transcript.

        Args:
            messages: Chat message dicts, e.g. ``{"role": ..., "content": ...}``
                (the same shape ``create_from_prompt`` appends). Stored by
                reference; the list is neither copied nor modified.
            tools: Optional tool definitions to attach to the task.

        Returns:
            A new ``ChatTask`` wrapping ``messages``.
        """
        task = ChatTask()
        task.messages = messages
        task.tools = tools
        return task
|