mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-06 03:01:50 +08:00
* support mcp # Conflicts: # tensorrt_llm/scaffolding/worker.py Signed-off-by: wu1du2 <wu1du2@gmail.com> * move all into contrib/mcp # Conflicts: # examples/scaffolding/contrib/mcp/mcptest.py # tensorrt_llm/scaffolding/__init__.py # tensorrt_llm/scaffolding/contrib/__init__.py # tensorrt_llm/scaffolding/contrib/mcp/__init__.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py # tensorrt_llm/scaffolding/task.py # tensorrt_llm/scaffolding/worker.py Signed-off-by: wu1du2 <wu1du2@gmail.com> * support sandbox, websearch # Conflicts: # examples/scaffolding/contrib/mcp/mcptest.py # examples/scaffolding/contrib/mcp/weather/weather.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py # tensorrt_llm/scaffolding/worker.py Signed-off-by: wu1du2 <wu1du2@gmail.com> * remove pics Signed-off-by: wu1du2 <wu1du2@gmail.com> * pre-commit fix # Conflicts: # tensorrt_llm/scaffolding/contrib/mcp/__init__.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py # tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py Signed-off-by: wu1du2 <wu1du2@gmail.com> * fix spell Signed-off-by: wu1du2 <wu1du2@gmail.com> * rebase Signed-off-by: wu1du2 <wu1du2@gmail.com> --------- Signed-off-by: wu1du2 <wu1du2@gmail.com>
46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import Optional, Union
|
|
|
|
from tensorrt_llm.scaffolding.task import Task
|
|
|
|
|
|
@dataclass
class MCPCallTask(Task):
    """Task describing a single MCP tool invocation and its retry policy.

    Fields:
        tool_name: name of the MCP tool to invoke.
        args: arguments to send with the tool call, as a dict.
        retry: retry setting for the call; NOTE(review): whether this counts
            total attempts or extra retries is not visible here — confirm
            against the worker that executes the task.
        delay: delay used between retries; NOTE(review): units are not
            visible here (presumably seconds) — confirm against the worker.
        worker_tag: tag used to route the task; either a plain string or a
            ``Controller.WorkerTag`` (a string forward reference, not
            imported in this module).
        result_str: textual result of the call, presumably filled in by
            whoever executes the task.
    """

    # MCP call inputs.
    tool_name: Optional[str] = None
    args: Optional[dict] = None

    # Retry control.
    retry: Optional[int] = 1
    # NOTE(review): this class-level default (10) differs from the default of
    # the ``delay`` parameter of create_mcptask (1); confirm which is intended.
    delay: Optional[float] = 10

    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # Result field.
    result_str: Optional[str] = None

    @staticmethod
    def create_mcptask(tool_name: str,
                       args: dict,
                       retry: int = 1,
                       delay: float = 1) -> "MCPCallTask":
        """Build an MCPCallTask for ``tool_name`` called with ``args``.

        Args:
            tool_name: name of the MCP tool to invoke.
            args: arguments for the tool call.
            retry: retry setting stored on the task (default 1).
            delay: retry delay stored on the task (default 1).

        Returns:
            A new MCPCallTask; every other field keeps its class default.
        """
        return MCPCallTask(tool_name=tool_name,
                           args=args,
                           retry=retry,
                           delay=delay)
|
|
|
|
|
|
@dataclass
class MCPListTask(Task):
    """Task requesting the list of tools from an MCP worker.

    Fields:
        worker_tag: tag used to route the task; either a plain string or a
            ``Controller.WorkerTag`` (a string forward reference, not
            imported in this module).
        result_str: textual result, presumably filled in by whoever executes
            the task.
        result_tools: tool listing result, presumably filled in by whoever
            executes the task.
    """

    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # Result fields.
    result_str: Optional[str] = None
    # Annotated so that @dataclass treats it as a real per-instance field
    # (included in __init__, __repr__, __eq__ and dataclasses.fields()).
    # Without an annotation @dataclass silently ignored it and it was only a
    # class attribute, unlike its sibling result_str.
    # NOTE(review): presumably a list of tool descriptions returned by the MCP
    # server — confirm the concrete type against the worker that sets it.
    result_tools: Optional[object] = None

    @staticmethod
    def create_mcptask() -> "MCPListTask":
        """Return a new MCPListTask with every field at its default."""
        return MCPListTask()
|