Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-14 06:53:50 +08:00)
ScaffoldingLlm supports MCP (#4410)
* support mcp
* move all into contrib/mcp
* support sandbox, websearch
* remove pics
* pre-commit fix
* fix spell
* rebase

Signed-off-by: wu1du2 <wu1du2@gmail.com>
This commit is contained in:
parent 338744fba6
commit 60a6c20174
examples/scaffolding/contrib/mcp/README.md (new file, 50 lines)
@@ -0,0 +1,50 @@
# MCP USAGE

## Step 1: Run Servers

### Terminal 1:

`cd weather`

`pip install uv`

`uv add "mcp[cli]" httpx openai`

`uv pip install httpx mcp`

`uv init --no-workspace`

`uv run weather.py`

### Terminal 2:

`cd e2b`

`pip install uv`

`uv add "mcp[cli]" httpx openai`

`uv pip install e2b_code_interpreter mcp`

`uv init --no-workspace`

`uv run e2bserver.py`

### Terminal 3:

`cd websearch`

`pip install uv`

`uv add "mcp[cli]" httpx openai`

`uv pip install brave-search mcp starlette uvicorn`

`uv init --no-workspace`

`uv run websearch.py`

## Step 2: Run Test

`python3 mcptest.py --API_KEY YOUR_API_KEY`
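A quick way to sanity-check that all three servers are reachable before running the test: a minimal hypothetical helper (`check_servers.py`, not one of the example files, assuming the default ports 8080, 8081, and 8082 used above) that lists each server's tools over the same SSE transport the worker uses.

```python
# check_servers.py (hypothetical): verify the three MCP SSE servers respond.
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def list_server_tools(url: str) -> None:
    # Open the SSE transport, start an MCP session, and print tool names.
    async with sse_client(url=url) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            response = await session.list_tools()
            print(url, "->", [tool.name for tool in response.tools])


if __name__ == "__main__":
    for port in (8080, 8081, 8082):
        asyncio.run(list_server_tools(f"http://0.0.0.0:{port}/sse"))
```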
examples/scaffolding/contrib/mcp/e2b/.env (new file, 1 line)
@@ -0,0 +1 @@
E2B_API_KEY=$YOUR_API_KEY
examples/scaffolding/contrib/mcp/e2b/README.md (new file, empty)
examples/scaffolding/contrib/mcp/e2b/e2bserver.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import logging

import uvicorn
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from pydantic import BaseModel
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for sandbox tools (SSE)
mcp = FastMCP("sandbox")

# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("e2b-mcp-server")


# Tool schema
class ToolSchema(BaseModel):
    code: str


@mcp.tool()
async def run_code(code: str) -> str:
    """Run Python code in a secure E2B sandbox, using Jupyter Notebook syntax.

    The response includes:
    1. results: the function return value.
    2. stdout: the standard output.
    3. stderr: the standard error.

    Args:
        code: string in Jupyter Notebook syntax.
    """
    sbx = Sandbox()
    execution = sbx.run_code(code)
    logger.info(f"Execution: {execution}")

    result = {
        "results": execution.results,
        "stdout": execution.logs.stdout,
        "stderr": execution.logs.stderr,
    }

    return f"{result}"


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided MCP server with SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    mcp_server = mcp._mcp_server  # noqa: WPS437

    import argparse

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8081,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)

    uvicorn.run(starlette_app, host=args.host, port=args.port)
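The `run_code` tool can also be exercised directly, without the scaffolding stack. A hedged smoke-test sketch (not part of the example files; it assumes the server above is running on its default port 8081 and that `E2B_API_KEY` is set in `.env`):

```python
# Hypothetical smoke test for the sandbox server's run_code tool.
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def main() -> None:
    async with sse_client(url="http://0.0.0.0:8081/sse") as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()
            # Arguments are passed as a plain dict on the client side.
            result = await session.call_tool("run_code",
                                             {"code": "print(1 + 1)"})
            print(result.content[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```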
examples/scaffolding/contrib/mcp/e2b/main.py (new file, 6 lines)
@@ -0,0 +1,6 @@
def main():
    print("Hello from e2b!")


if __name__ == "__main__":
    main()
examples/scaffolding/contrib/mcp/e2b/pyproject.toml (new file, 7 lines)
@@ -0,0 +1,7 @@
[project]
name = "e2b"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
examples/scaffolding/contrib/mcp/mcptest.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import argparse
import asyncio

from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController,
                                              MCPWorker, chat_handler)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--base_url',
        type=str,
        default="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    parser.add_argument(
        '--model',
        type=str,
        default="qwen-plus-latest",
    )
    parser.add_argument('--API_KEY', type=str)
    args = parser.parse_args()
    return args


async def main():
    args = parse_arguments()
    prompts = [
        # "What's the weather like today in LA?"
        # 'Solve the problem with running python code: What is the number of Fibonacci array 20th element? The array goes like 0,1,1,2,3...'
        # 'Which game won TGA Best Action Game and Players Voice awards in 2024?'
        'What was the score of the NBA playoffs game 7 between the Thunder and the Nuggets in 2025?'
    ]
    API_KEY = args.API_KEY
    urls = [
        "http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
        "http://0.0.0.0:8082/sse"
    ]
    print(f"API_KEY {API_KEY}")
    client = AsyncOpenAI(api_key=API_KEY, base_url=args.base_url)
    qwen_worker = OpenaiWorker(client, args.model)
    qwen_worker.register_task_handler(ChatTask, chat_handler)
    mcp_worker = await MCPWorker.init_with_urls(urls)

    prototype_controller = MCPController()
    llm = ScaffoldingLlm(
        prototype_controller,
        {
            MCPController.WorkerTag.GENERATION: qwen_worker,
            MCPController.WorkerTag.MCP: mcp_worker
        },
    )

    future = llm.generate_async(prompts[0])
    result = await future.aresult()
    print(f"\nresult is {result.output.output_str}\n")

    print('main shutting down...')
    llm.shutdown()
    print('worker shutting down...')
    qwen_worker.shutdown()
    mcp_worker.shutdown()

    print('main shut down done')
    return


if __name__ == '__main__':
    asyncio.run(main())
examples/scaffolding/contrib/mcp/weather/pyproject.toml (new file, 12 lines)
@@ -0,0 +1,12 @@
[project]
name = "mcp"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []

[tool.uv.workspace]
members = [
    "e2b",
]
examples/scaffolding/contrib/mcp/weather/weather.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import Any

import httpx
import uvicorn
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for Weather tools (SSE)
mcp = FastMCP("weather")

# Constants
NWS_API_BASE = "https://api.weather.gov"
USER_AGENT = "weather-app/1.0"


async def make_nws_request(url: str) -> dict[str, Any] | None:
    """Make a request to the NWS API with proper error handling."""
    headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"}
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url, headers=headers, timeout=30.0)
            response.raise_for_status()
            return response.json()
        except Exception:
            return None


def format_alert(feature: dict) -> str:
    """Format an alert feature into a readable string."""
    props = feature["properties"]
    return f"""
Event: {props.get('event', 'Unknown')}
Area: {props.get('areaDesc', 'Unknown')}
Severity: {props.get('severity', 'Unknown')}
Description: {props.get('description', 'No description available')}
Instructions: {props.get('instruction', 'No specific instructions provided')}
"""


@mcp.tool()
async def get_alerts(state: str) -> str:
    """Get weather alerts for a US state.

    Args:
        state: Two-letter US state code (e.g. CA, NY)
    """
    url = f"{NWS_API_BASE}/alerts/active/area/{state}"
    data = await make_nws_request(url)

    if not data or "features" not in data:
        return "Unable to fetch alerts or no alerts found."

    if not data["features"]:
        return "No active alerts for this state."

    alerts = [format_alert(feature) for feature in data["features"]]
    return "\n---\n".join(alerts)


@mcp.tool()
async def get_forecast(latitude: float, longitude: float) -> str:
    """Get weather forecast for a location.

    Args:
        latitude: Latitude of the location
        longitude: Longitude of the location
    """
    # First get the forecast grid endpoint
    points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
    points_data = await make_nws_request(points_url)

    if not points_data:
        return "Unable to fetch forecast data for this location."

    # Get the forecast URL from the points response
    forecast_url = points_data["properties"]["forecast"]
    forecast_data = await make_nws_request(forecast_url)

    if not forecast_data:
        return "Unable to fetch detailed forecast."

    # Format the periods into a readable forecast
    periods = forecast_data["properties"]["periods"]
    forecasts = []
    for period in periods[:5]:  # Only show next 5 periods
        forecast = f"""
{period['name']}:
Temperature: {period['temperature']}°{period['temperatureUnit']}
Wind: {period['windSpeed']} {period['windDirection']}
Forecast: {period['detailedForecast']}
"""
        forecasts.append(forecast)

    return "\n---\n".join(forecasts)


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided MCP server with SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    mcp_server = mcp._mcp_server  # noqa: WPS437

    import argparse

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8080,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)

    uvicorn.run(starlette_app, host=args.host, port=args.port)
examples/scaffolding/contrib/mcp/websearch/.env (new file, 1 line)
@@ -0,0 +1 @@
BRAVE_API_KEY=$YOUR_API_KEY
examples/scaffolding/contrib/mcp/websearch/main.py (new file, 6 lines)
@@ -0,0 +1,6 @@
def main():
    print("Hello from websearch!")


if __name__ == "__main__":
    main()
examples/scaffolding/contrib/mcp/websearch/pyproject.toml (new file, 11 lines)
@@ -0,0 +1,11 @@
[project]
name = "websearch"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "httpx>=0.28.1",
    "mcp[cli]>=1.9.0",
    "openai>=1.79.0",
]
examples/scaffolding/contrib/mcp/websearch/websearch.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import os

import uvicorn
from brave import Brave
from dotenv import load_dotenv
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for websearch tools (SSE)
mcp = FastMCP("websearch")

# Load environment variables
load_dotenv()


@mcp.tool()
async def websearch(query: str) -> str:
    """Web search: fetch information from the internet.

    Args:
        query: string of what you want to search
    """
    BRAVE_API_KEY = os.getenv("BRAVE_API_KEY")
    brave = Brave(BRAVE_API_KEY)
    print(f"brave apikey {BRAVE_API_KEY}")
    search_results = brave.search(q=query, raw=True)
    return f"{search_results}"


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided MCP server with SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    mcp_server = mcp._mcp_server  # noqa: WPS437

    import argparse

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8082,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)

    uvicorn.run(starlette_app, host=args.host, port=args.port)
tensorrt_llm/scaffolding/contrib/__init__.py (modified)
@@ -2,6 +2,8 @@ from tensorrt_llm.scaffolding import *  # noqa

from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
from .Dynasor import DynasorGenerationController
from .mcp import (ChatTask, MCPCallTask, MCPController, MCPListTask, MCPWorker,
                  chat_handler)

__all__ = [
    # AsyncGeneration
@@ -9,4 +11,11 @@ __all__ = [
    "StreamGenerationTask",
    # Dynasor
    "DynasorGenerationController",
    # mcp
    "MCPController",
    "MCPWorker",
    "MCPCallTask",
    "MCPListTask",
    "ChatTask",
    "chat_handler"
]
tensorrt_llm/scaffolding/contrib/mcp/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from .chat_handler import chat_handler
from .chat_task import ChatTask
from .mcp_controller import MCPController
from .mcp_task import MCPCallTask, MCPListTask
from .mcp_worker import MCPWorker

__all__ = [
    "MCPController", "MCPWorker", "MCPCallTask", "MCPListTask", "ChatTask",
    "chat_handler"
]
tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import openai

from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.scaffolding import TaskStatus
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask

ExecutorCls = GenerationExecutor


# helper function:
# add the first non-None value in candidate_values to params under key
def add_param_if_not_none(params, key, candidate_values):
    for value in candidate_values:
        if value is not None:
            params[key] = value
            return


def combine_params_with_chat_task(worker, params: dict, task: ChatTask):
    params["messages"] = task.messages

    add_param_if_not_none(params, "max_tokens",
                          [task.max_tokens, worker.max_tokens])
    add_param_if_not_none(params, "temperature",
                          [task.temperature, worker.temperature])
    add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p])

    add_param_if_not_none(params, "tools", [task.tools])


def fill_chat_task_with_response(task: ChatTask, response: openai.Completion):
    task.output_str = response.choices[0].message.content
    task.finish_reason = response.choices[0].finish_reason
    task.tool_calls = response.choices[0].message.tool_calls
    task.logprobs = response.choices[0].logprobs


async def chat_handler(worker, task: ChatTask) -> TaskStatus:
    params = {}
    # Set required parameters
    params["model"] = worker.model

    combine_params_with_chat_task(worker, params, task)

    # Make the API call
    try:
        response = await worker.async_client.chat.completions.create(**params)
        fill_chat_task_with_response(task, response)
        return TaskStatus.SUCCESS

    except Exception as e:
        # Handle errors
        print('OpenAI chat client got exception: ' + str(e))
        return TaskStatus.WORKER_EXECEPTION
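Task-level sampling values take precedence over the worker's defaults because `add_param_if_not_none` stops at the first non-None candidate. A small illustrative sketch (the stub worker and its values are hypothetical stand-ins for the attributes `OpenaiWorker` carries):

```python
from tensorrt_llm.scaffolding.contrib.mcp.chat_handler import \
    combine_params_with_chat_task
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask


class StubWorker:
    # hypothetical stand-in for a worker's default sampling attributes
    max_tokens = 1024
    temperature = 0.7
    top_p = 0.9


task = ChatTask.create_from_prompt([], "hi", tools=None)
task.max_tokens = 256  # task-level override

params = {}
combine_params_with_chat_task(StubWorker(), params, task)
assert params["max_tokens"] == 256  # the task value wins
assert params["temperature"] == 0.7  # falls back to the worker default
assert "tools" not in params  # None candidates are skipped entirely
```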
tensorrt_llm/scaffolding/contrib/mcp/chat_task.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from dataclasses import dataclass

from tensorrt_llm.scaffolding import GenerationTask


@dataclass
class ChatTask(GenerationTask):
    messages: list = None
    tools = None
    finish_reason = None
    tool_calls = None

    @staticmethod
    def create_from_prompt(messages: list, prompt: str, tools) -> "ChatTask":
        task = ChatTask()
        messages.append({"role": "user", "content": prompt})
        task.messages = messages
        task.tools = tools
        return task
tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import copy
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding import Controller, Task

from .chat_task import ChatTask
from .mcp_task import MCPCallTask, MCPListTask


class MCPController(Controller):

    class WorkerTag(Enum):
        GENERATION = "generation"
        MCP = "mcp"

    def __init__(self, custom_sampling_params: dict = None):
        super().__init__()
        self.custom_sampling_params = copy.deepcopy(
            custom_sampling_params) if custom_sampling_params else None

    def process(self, tasks: List[Task], **kwargs):
        list_task = MCPListTask.create_mcptask()
        list_task.worker_tag = MCPController.WorkerTag.MCP
        yield [list_task]
        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
        } for tool in list_task.result_tools]

        print(f"\navailable_tools {available_tools}\n")
        assert (len(tasks) == 1)
        system_message = (
            "You are a helpful assistant with access to tools:\n\n"
            "After receiving a tool's response:\n"
            "1. Transform the raw data into a natural, conversational response\n"
            "2. Keep responses concise but informative\n"
            "3. Focus on the most relevant information\n"
            "4. Use appropriate context from the user's question\n"
            "5. Avoid simply repeating the raw data\n\n"
            "Please use only the tools that are explicitly defined above.")
        messages = [{"role": "system", "content": system_message}]
        chattask = ChatTask.create_from_prompt(messages, tasks[0].input_str,
                                               available_tools)
        result_task = tasks[0]
        chattask.worker_tag = self.WorkerTag.GENERATION
        if self.custom_sampling_params:
            for key, value in self.custom_sampling_params.items():
                if hasattr(tasks[0], key) and getattr(tasks[0], key) is None:
                    setattr(tasks[0], key, value)
        yield [chattask]
        if chattask.finish_reason != 'tool_calls':
            result_task.output_str = chattask.output_str
            return
        tool_calls = chattask.tool_calls
        mcp_call_tasks = [
            MCPCallTask.create_mcptask(tool_call.function.name,
                                       tool_call.function.arguments)
            for tool_call in tool_calls
        ]
        for task in mcp_call_tasks:
            task.worker_tag = MCPController.WorkerTag.MCP
        print(f"\nmcp_call_tasks is {mcp_call_tasks}\n")
        yield mcp_call_tasks
        mcp_result = mcp_call_tasks[0].output_str
        print(f"\nmcp_result is {mcp_result}\n")
        messages.append({"role": "assistant", "content": chattask.output_str})
        finalchattask = ChatTask.create_from_prompt(messages, mcp_result,
                                                    available_tools)
        finalchattask.worker_tag = self.WorkerTag.GENERATION
        yield [finalchattask]
        result_task.output_str = finalchattask.output_str
        return
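`process` is an ordinary generator: every `yield [task, ...]` hands a batch to `ScaffoldingLlm`, which routes each task to the worker registered for its `worker_tag` and resumes the generator once the workers have written their results back onto the task objects. A simplified, hypothetical driver that illustrates this contract (the real scheduler inside `ScaffoldingLlm` batches and runs tasks concurrently):

```python
# Hypothetical driver loop; it illustrates the yield contract only.
async def drive(controller, tasks, workers):
    for task_batch in controller.process(tasks):
        for task in task_batch:
            worker = workers[task.worker_tag]
            # task_handlers maps task types to handler coroutines, e.g.
            # {MCPListTask: list_handler, MCPCallTask: call_handler}.
            handler = type(worker).task_handlers[type(task)]
            await handler(worker, task)  # the handler mutates the task
```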
tensorrt_llm/scaffolding/contrib/mcp/mcp_task.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from tensorrt_llm.scaffolding.task import Task


@dataclass
class MCPCallTask(Task):
    # mcp inputs
    tool_name: Optional[str] = field(default=None)
    args: Optional[dict] = field(default=None)
    # retrying control
    retry: Optional[int] = field(default=1)
    delay: Optional[float] = field(default=10)

    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # result field
    result_str: Optional[str] = None

    @staticmethod
    def create_mcptask(tool_name: str,
                       args: dict,
                       retry: int = 1,
                       delay: float = 1) -> "MCPCallTask":
        task = MCPCallTask()
        task.tool_name = tool_name
        task.args = args
        task.retry = retry
        task.delay = delay
        return task


@dataclass
class MCPListTask(Task):
    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # result field
    result_str: Optional[str] = None
    result_tools = None

    @staticmethod
    def create_mcptask() -> "MCPListTask":
        task = MCPListTask()
        return task
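One subtlety of `MCPCallTask`: when `MCPController` builds it from a model tool call, `args` is the raw JSON string taken from `tool_call.function.arguments`, and `MCPWorker.call_handler` decodes it with `json.loads`. A minimal illustrative sketch (the tool name and arguments are just examples):

```python
from tensorrt_llm.scaffolding.contrib.mcp.mcp_task import MCPCallTask

# args is the raw JSON string produced by the model's tool call;
# MCPWorker.call_handler later decodes it with json.loads.
task = MCPCallTask.create_mcptask("get_alerts", '{"state": "CA"}')
task.worker_tag = "mcp"  # normally set to MCPController.WorkerTag.MCP
```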
tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import asyncio
import sys
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from mcp import ClientSession
from mcp.client.sse import sse_client

load_dotenv()  # load environment variables from .env


class MCPClient:

    def __init__(self):
        # Initialize session and client objects
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def list_tools(self):
        response = await self.session.list_tools()
        return response

    async def call_tool(self, tool_name, tool_args):
        result = await self.session.call_tool(tool_name, tool_args)
        return result

    async def connect_to_sse_server(self, server_url: str):
        """Connect to an MCP server running with SSE transport."""
        streams_context = sse_client(url=server_url)
        streams = await self.exit_stack.enter_async_context(streams_context)

        session_context = ClientSession(*streams)
        self.session = await self.exit_stack.enter_async_context(
            session_context)

        # Initialize session
        await self.session.initialize()

    async def cleanup(self):
        """Properly clean up all registered async resources."""
        await self.exit_stack.aclose()


async def main():
    if len(sys.argv) < 2:
        print(
            "Usage: uv run client.py <URL of SSE MCP server (i.e. http://localhost:8080/sse)>"
        )
        sys.exit(1)

    client = MCPClient()
    try:
        await client.connect_to_sse_server(server_url=sys.argv[1])
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import asyncio
import json
from typing import List

from tensorrt_llm.scaffolding import TaskStatus, Worker

from .mcp_task import MCPCallTask, MCPListTask
from .mcp_utils import MCPClient


class MCPWorker(Worker):

    def __init__(
        self,
        mcp_clients: List,
    ):
        self.mcp_clients = mcp_clients

    @classmethod
    async def init_with_urls(cls, urls):
        clients = []
        for url in urls:
            client = MCPClient()
            await client.connect_to_sse_server(server_url=url)
            clients.append(client)
        return cls(clients)

    async def call_handler(self, task: MCPCallTask) -> TaskStatus:
        for mcp_client in self.mcp_clients:
            response = await mcp_client.list_tools()
            for tool in response.tools:
                if task.tool_name not in tool.name:
                    continue
                print(f"\ncall handler {tool.name} and {task.tool_name}\n")
                tool_name = task.tool_name
                args = json.loads(task.args)
                response = await mcp_client.call_tool(tool_name, args)
                task.output_str = response.content[0].text
                return TaskStatus.SUCCESS

    async def list_handler(self, task: MCPListTask) -> TaskStatus:
        result_tools = []
        for mcp_client in self.mcp_clients:
            response = await mcp_client.list_tools()
            result_tools.extend(response.tools)
        task.result_tools = result_tools
        return TaskStatus.SUCCESS

    def shutdown(self):
        loop = asyncio.get_event_loop()
        for mcp_client in self.mcp_clients:
            if loop.is_running():
                asyncio.run_coroutine_threadsafe(mcp_client.cleanup(), loop)
            else:
                loop.run_until_complete(mcp_client.cleanup())

    task_handlers = {MCPListTask: list_handler, MCPCallTask: call_handler}
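`MCPWorker` can also be exercised on its own, outside `ScaffoldingLlm`. A hedged sketch (not part of the PR) assuming the weather server from the README is running on its default port:

```python
# Hypothetical standalone use of MCPWorker against the example servers.
import asyncio

from tensorrt_llm.scaffolding.contrib.mcp import MCPWorker
from tensorrt_llm.scaffolding.contrib.mcp.mcp_task import MCPListTask


async def main() -> None:
    worker = await MCPWorker.init_with_urls(["http://0.0.0.0:8080/sse"])
    try:
        task = MCPListTask.create_mcptask()
        await worker.list_handler(task)
        print([tool.name for tool in task.result_tools])
    finally:
        # Close the SSE sessions directly; worker.shutdown() targets the
        # mixed sync/async shutdown path used by ScaffoldingLlm.
        for client in worker.mcp_clients:
            await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
```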