ScaffoldingLlm supports MCP (#4410)

* support mcp

# Conflicts:
#	tensorrt_llm/scaffolding/worker.py

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* move all into contrib/mcp

# Conflicts:
#	examples/scaffolding/contrib/mcp/mcptest.py
#	tensorrt_llm/scaffolding/__init__.py
#	tensorrt_llm/scaffolding/contrib/__init__.py
#	tensorrt_llm/scaffolding/contrib/mcp/__init__.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py
#	tensorrt_llm/scaffolding/task.py
#	tensorrt_llm/scaffolding/worker.py

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* support sandbox, websearch

# Conflicts:
#	examples/scaffolding/contrib/mcp/mcptest.py
#	examples/scaffolding/contrib/mcp/weather/weather.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py
#	tensorrt_llm/scaffolding/worker.py

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* remove pics

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* pre-commit fix

# Conflicts:
#	tensorrt_llm/scaffolding/contrib/mcp/__init__.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py
#	tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* fix spell

Signed-off-by: wu1du2 <wu1du2@gmail.com>

* rebase

Signed-off-by: wu1du2 <wu1du2@gmail.com>

---------

Signed-off-by: wu1du2 <wu1du2@gmail.com>
22 changed files with 815 additions and 0 deletions


@ -0,0 +1,50 @@
# MCP Usage

## Step 1: Run Servers

### Terminal 1:
```bash
cd weather
pip install uv
uv add "mcp[cli]" httpx openai
uv pip install httpx mcp
uv init --no-workspace
uv run weather.py
```

### Terminal 2:
```bash
cd e2b
pip install uv
uv add "mcp[cli]" httpx openai
uv pip install e2b_code_interpreter mcp
uv init --no-workspace
uv run e2bserver.py
```

### Terminal 3:
```bash
cd websearch
pip install uv
uv add "mcp[cli]" httpx openai
uv pip install brave-search mcp starlette uvicorn
uv init --no-workspace
uv run websearch.py
```

## Step 2: Run Test
```bash
python3 mcptest.py --API_KEY YOUR_API_KEY
```
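If a server does not respond in Step 2, it can help to probe each endpoint directly. Below is a minimal sketch (a hypothetical `check_server.py`, not part of this example) that connects to one SSE endpoint with the `mcp` client SDK and lists the tools it exposes:

```python
import asyncio
import sys

from mcp import ClientSession
from mcp.client.sse import sse_client


async def main(url: str) -> None:
    # Open the SSE transport, start a session, and list the server's tools.
    async with sse_client(url=url) as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])


if __name__ == "__main__":
    # e.g. python3 check_server.py http://0.0.0.0:8080/sse
    asyncio.run(main(sys.argv[1]))
```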


@ -0,0 +1 @@
E2B_API_KEY=$YOUR_API_KEY


@ -0,0 +1,92 @@
import logging

import uvicorn
from dotenv import load_dotenv
from e2b_code_interpreter import Sandbox
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from pydantic import BaseModel
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for sandbox tools (SSE)
mcp = FastMCP("sandbox")

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("e2b-mcp-server")


# Tool schema
class ToolSchema(BaseModel):
    code: str


@mcp.tool()
async def run_code(code: str) -> str:
    """Run Python code in a secure E2B sandbox, using Jupyter Notebook syntax.

    The response includes: 1. results, the function return value;
    2. stdout, the standard output; 3. stderr, the standard error.

    Args:
        code: string in Jupyter Notebook syntax.
    """
    sbx = Sandbox()
    execution = sbx.run_code(code)
    logger.info(f"Execution: {execution}")
    result = {
        "results": execution.results,
        "stdout": execution.logs.stdout,
        "stderr": execution.logs.stderr,
    }
    return f"{result}"


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that serves the provided MCP server over SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    import argparse

    mcp_server = mcp._mcp_server  # noqa: WPS437

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8081,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)
    uvicorn.run(starlette_app, host=args.host, port=args.port)
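For reference, a minimal client-side sketch of invoking the `run_code` tool above (assuming the server is already running on port 8081 and `E2B_API_KEY` is set in `.env`):

```python
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def demo() -> None:
    async with sse_client(url="http://0.0.0.0:8081/sse") as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()
            # run_code returns results, stdout, and stderr as a single string
            result = await session.call_tool(
                "run_code", {"code": "print(sum(range(10)))"})
            print(result.content[0].text)


asyncio.run(demo())
```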


@ -0,0 +1,6 @@
def main():
    print("Hello from e2b!")


if __name__ == "__main__":
    main()


@ -0,0 +1,7 @@
[project]
name = "e2b"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []


@ -0,0 +1,77 @@
import argparse
import asyncio

from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController,
                                              MCPWorker, chat_handler)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--base_url',
        type=str,
        default="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    parser.add_argument(
        '--model',
        type=str,
        default="qwen-plus-latest",
    )
    parser.add_argument('--API_KEY', type=str)
    args = parser.parse_args()
    return args


async def main():
    args = parse_arguments()
    prompts = [
        # "What's the weather like today in LA?"
        # 'Solve the problem with running python code: What is the number of Fibonacci array 20th element? The array goes like 0,1,1,2,3...'
        # 'Which game won TGA Best Action Game and Players Voice awards in 2024?'
        'What was the score of the NBA playoffs game 7 between the Thunder and the Nuggets in 2025?'
    ]
    API_KEY = args.API_KEY
    urls = [
        "http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
        "http://0.0.0.0:8082/sse"
    ]
    print(f"API_KEY {API_KEY}")

    client = AsyncOpenAI(api_key=API_KEY, base_url=args.base_url)
    qwen_worker = OpenaiWorker(client, args.model)
    qwen_worker.register_task_handler(ChatTask, chat_handler)
    mcp_worker = await MCPWorker.init_with_urls(urls)

    prototype_controller = MCPController()
    llm = ScaffoldingLlm(
        prototype_controller,
        {
            MCPController.WorkerTag.GENERATION: qwen_worker,
            MCPController.WorkerTag.MCP: mcp_worker
        },
    )
    future = llm.generate_async(prompts[0])
    result = await future.aresult()
    print(f"\nresult is {result.output.output_str}\n")

    print('main shutting down...')
    llm.shutdown()
    print('worker shutting down...')
    qwen_worker.shutdown()
    mcp_worker.shutdown()
    print('main shut down done')
    return


if __name__ == '__main__':
    asyncio.run(main())


@ -0,0 +1,12 @@
[project]
name = "mcp"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []
[tool.uv.workspace]
members = [
"e2b",
]


@ -0,0 +1,144 @@
from typing import Any

import httpx
import uvicorn
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for Weather tools (SSE)
mcp = FastMCP("weather")

# Constants
NWS_API_BASE = "https://api.weather.gov"
USER_AGENT = "weather-app/1.0"


async def make_nws_request(url: str) -> dict[str, Any] | None:
    """Make a request to the NWS API with proper error handling."""
    headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"}
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url, headers=headers, timeout=30.0)
            response.raise_for_status()
            return response.json()
        except Exception:
            return None


def format_alert(feature: dict) -> str:
    """Format an alert feature into a readable string."""
    props = feature["properties"]
    return f"""
Event: {props.get('event', 'Unknown')}
Area: {props.get('areaDesc', 'Unknown')}
Severity: {props.get('severity', 'Unknown')}
Description: {props.get('description', 'No description available')}
Instructions: {props.get('instruction', 'No specific instructions provided')}
"""


@mcp.tool()
async def get_alerts(state: str) -> str:
    """Get weather alerts for a US state.

    Args:
        state: Two-letter US state code (e.g. CA, NY)
    """
    url = f"{NWS_API_BASE}/alerts/active/area/{state}"
    data = await make_nws_request(url)

    if not data or "features" not in data:
        return "Unable to fetch alerts or no alerts found."

    if not data["features"]:
        return "No active alerts for this state."

    alerts = [format_alert(feature) for feature in data["features"]]
    return "\n---\n".join(alerts)


@mcp.tool()
async def get_forecast(latitude: float, longitude: float) -> str:
    """Get weather forecast for a location.

    Args:
        latitude: Latitude of the location
        longitude: Longitude of the location
    """
    # First get the forecast grid endpoint
    points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
    points_data = await make_nws_request(points_url)

    if not points_data:
        return "Unable to fetch forecast data for this location."

    # Get the forecast URL from the points response
    forecast_url = points_data["properties"]["forecast"]
    forecast_data = await make_nws_request(forecast_url)

    if not forecast_data:
        return "Unable to fetch detailed forecast."

    # Format the periods into a readable forecast
    periods = forecast_data["properties"]["periods"]
    forecasts = []
    for period in periods[:5]:  # Only show next 5 periods
        forecast = f"""
{period['name']}:
Temperature: {period['temperature']}°{period['temperatureUnit']}
Wind: {period['windSpeed']} {period['windDirection']}
Forecast: {period['detailedForecast']}
"""
        forecasts.append(forecast)

    return "\n---\n".join(forecasts)


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that serves the provided MCP server over SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    import argparse

    mcp_server = mcp._mcp_server  # noqa: WPS437

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8080,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)
    uvicorn.run(starlette_app, host=args.host, port=args.port)


@ -0,0 +1 @@
BRAVE_API_KEY=$YOUR_API_KEY


@ -0,0 +1,6 @@
def main():
    print("Hello from websearch!")


if __name__ == "__main__":
    main()


@ -0,0 +1,11 @@
[project]
name = "websearch"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"httpx>=0.28.1",
"mcp[cli]>=1.9.0",
"openai>=1.79.0",
]


@ -0,0 +1,76 @@
import os

import uvicorn
from brave import Brave
from dotenv import load_dotenv
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route

# Initialize FastMCP server for websearch tools (SSE)
mcp = FastMCP("websearch")

# Load environment variables
load_dotenv()


@mcp.tool()
async def websearch(query: str) -> str:
    """Web search: fetch information from the internet.

    Args:
        query: string of what you want to search
    """
    BRAVE_API_KEY = os.getenv("BRAVE_API_KEY")
    brave = Brave(BRAVE_API_KEY)
    print(f"brave apikey {BRAVE_API_KEY}")
    search_results = brave.search(q=query, raw=True)
    return f"{search_results}"


def create_starlette_app(mcp_server: Server,
                         *,
                         debug: bool = False) -> Starlette:
    """Create a Starlette application that serves the provided MCP server over SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    import argparse

    mcp_server = mcp._mcp_server  # noqa: WPS437

    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port',
                        type=int,
                        default=8082,
                        help='Port to listen on')
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)
    uvicorn.run(starlette_app, host=args.host, port=args.port)


@ -2,6 +2,8 @@ from tensorrt_llm.scaffolding import * # noqa
from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
from .Dynasor import DynasorGenerationController
from .mcp import (ChatTask, MCPCallTask, MCPController, MCPListTask, MCPWorker,
                  chat_handler)

__all__ = [
    # AsyncGeneration
@ -9,4 +11,11 @@ __all__ = [
    "StreamGenerationTask",
    # Dynasor
    "DynasorGenerationController",
    # MCP
    "MCPController",
    "MCPWorker",
    "MCPCallTask",
    "MCPListTask",
    "ChatTask",
    "chat_handler"
]


@ -0,0 +1,10 @@
from .chat_handler import chat_handler
from .chat_task import ChatTask
from .mcp_controller import MCPController
from .mcp_task import MCPCallTask, MCPListTask
from .mcp_worker import MCPWorker

__all__ = [
    "MCPController", "MCPWorker", "MCPCallTask", "MCPListTask", "ChatTask",
    "chat_handler"
]


@ -0,0 +1,54 @@
import openai

from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.scaffolding import TaskStatus
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask

ExecutorCls = GenerationExecutor


# Helper function: add the first non-None value from candidate_values
# to params under the given key.
def add_param_if_not_none(params, key, candidate_values):
    for value in candidate_values:
        if value is not None:
            params[key] = value
            return


def combine_params_with_chat_task(worker, params: dict, task: ChatTask):
    params["messages"] = task.messages
    add_param_if_not_none(params, "max_tokens",
                          [task.max_tokens, worker.max_tokens])
    add_param_if_not_none(params, "temperature",
                          [task.temperature, worker.temperature])
    add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p])
    add_param_if_not_none(params, "tools", [task.tools])


def fill_chat_task_with_response(task: ChatTask, response: openai.Completion):
    task.output_str = response.choices[0].message.content
    task.finish_reason = response.choices[0].finish_reason
    task.tool_calls = response.choices[0].message.tool_calls
    task.logprobs = response.choices[0].logprobs


async def chat_handler(worker, task: ChatTask) -> TaskStatus:
    params = {}

    # Set required parameters
    params["model"] = worker.model
    combine_params_with_chat_task(worker, params, task)

    # Make the API call
    try:
        response = await worker.async_client.chat.completions.create(**params)
        fill_chat_task_with_response(task, response)
        return TaskStatus.SUCCESS
    except Exception as e:
        # Handle errors
        print('OpenAI chat client got exception: ' + str(e))
        return TaskStatus.WORKER_EXECEPTION


@ -0,0 +1,19 @@
from dataclasses import dataclass

from tensorrt_llm.scaffolding import GenerationTask


@dataclass
class ChatTask(GenerationTask):
    messages: list = None
    tools = None
    finish_reason = None
    tool_calls = None

    @staticmethod
    def create_from_prompt(messages: list, prompt: str, tools) -> "ChatTask":
        task = ChatTask()
        messages.append({"role": "user", "content": prompt})
        task.messages = messages
        task.tools = tools
        return task


@ -0,0 +1,78 @@
import copy
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding import Controller, Task

from .chat_task import ChatTask
from .mcp_task import MCPCallTask, MCPListTask


class MCPController(Controller):

    class WorkerTag(Enum):
        GENERATION = "generation"
        MCP = "mcp"

    def __init__(self, custom_sampling_params: dict = None):
        super().__init__()
        self.custom_sampling_params = copy.deepcopy(
            custom_sampling_params) if custom_sampling_params else None

    def process(self, tasks: List[Task], **kwargs):
        list_task = MCPListTask.create_mcptask()
        list_task.worker_tag = MCPController.WorkerTag.MCP
        yield [list_task]

        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
        } for tool in list_task.result_tools]
        print(f"\navailable_tools {available_tools}\n")

        assert (len(tasks) == 1)
        system_message = (
            "You are a helpful assistant with access to tools:\n\n"
            "After receiving a tool's response:\n"
            "1. Transform the raw data into a natural, conversational response\n"
            "2. Keep responses concise but informative\n"
            "3. Focus on the most relevant information\n"
            "4. Use appropriate context from the user's question\n"
            "5. Avoid simply repeating the raw data\n\n"
            "Please use only the tools that are explicitly defined above.")
        messages = [{"role": "system", "content": system_message}]
        chattask = ChatTask.create_from_prompt(messages, tasks[0].input_str,
                                               available_tools)
        result_task = tasks[0]
        chattask.worker_tag = self.WorkerTag.GENERATION
        if self.custom_sampling_params:
            for key, value in self.custom_sampling_params.items():
                if hasattr(tasks[0], key) and getattr(tasks[0], key) is None:
                    setattr(tasks[0], key, value)
        yield [chattask]

        if chattask.finish_reason != 'tool_calls':
            result_task.output_str = chattask.output_str
            return

        tool_calls = chattask.tool_calls
        mcp_call_tasks = [
            MCPCallTask.create_mcptask(tool_call.function.name,
                                       tool_call.function.arguments)
            for tool_call in tool_calls
        ]
        for task in mcp_call_tasks:
            task.worker_tag = MCPController.WorkerTag.MCP
        print(f"\nmcp_call_tasks is {mcp_call_tasks}\n")
        yield mcp_call_tasks

        mcp_result = mcp_call_tasks[0].output_str
        print(f"\nmcp_result is {mcp_result}\n")
        messages.append({"role": "assistant", "content": chattask.output_str})
        finalchattask = ChatTask.create_from_prompt(messages, mcp_result,
                                                    available_tools)
        finalchattask.worker_tag = self.WorkerTag.GENERATION
        yield [finalchattask]

        result_task.output_str = finalchattask.output_str
        return
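The generator protocol above follows the usual Scaffolding controller contract: each `yield` hands a batch of tasks to `ScaffoldingLlm`, which dispatches every task to the worker registered for its `worker_tag` and resumes the generator once results are filled in. For a single tool call, the batches yielded by `process` look like this (a descriptive sketch, not executable flow):

```python
# Batches yielded by MCPController.process for one tool call:
# 1. [MCPListTask]      -> MCPWorker.list_handler fills result_tools
# 2. [ChatTask]         -> chat_handler; the model answers directly or
#                          requests tools (finish_reason == 'tool_calls')
# 3. [MCPCallTask, ...] -> MCPWorker.call_handler fills output_str
# 4. [ChatTask]         -> final response grounded in the tool result
```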


@ -0,0 +1,45 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from tensorrt_llm.scaffolding.task import Task


@dataclass
class MCPCallTask(Task):
    # mcp inputs
    tool_name: Optional[str] = field(default=None)
    args: Optional[dict] = field(default=None)
    # retrying control
    retry: Optional[int] = field(default=1)
    delay: Optional[float] = field(default=10)
    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # result field
    result_str: Optional[str] = None

    @staticmethod
    def create_mcptask(tool_name: str,
                       args: dict,
                       retry: int = 1,
                       delay: float = 1) -> "MCPCallTask":
        task = MCPCallTask()
        task.tool_name = tool_name
        task.args = args
        task.retry = retry
        task.delay = delay
        return task


@dataclass
class MCPListTask(Task):
    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # result field
    result_str: Optional[str] = None
    result_tools = None

    @staticmethod
    def create_mcptask() -> "MCPListTask":
        task = MCPListTask()
        return task
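One subtlety: although `args` is annotated as `Optional[dict]`, `MCPController` builds call tasks from `tool_call.function.arguments`, which is a JSON-encoded string, and `MCPWorker.call_handler` decodes it with `json.loads`. A small sketch of constructing a call task by hand (the tool name and query are illustrative):

```python
from tensorrt_llm.scaffolding.contrib.mcp import MCPCallTask, MCPController

# args is the JSON-encoded argument string, not a dict, because
# MCPWorker.call_handler runs json.loads(task.args) before the call.
task = MCPCallTask.create_mcptask("websearch",
                                  '{"query": "NWS weather alerts CA"}')
task.worker_tag = MCPController.WorkerTag.MCP  # route to the MCP worker
```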


@ -0,0 +1,60 @@
import asyncio
import sys
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from mcp import ClientSession
from mcp.client.sse import sse_client

load_dotenv()  # load environment variables from .env


class MCPClient:

    def __init__(self):
        # Initialize session and client objects
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def list_tools(self):
        response = await self.session.list_tools()
        return response

    async def call_tool(self, tool_name, tool_args):
        result = await self.session.call_tool(tool_name, tool_args)
        return result

    async def connect_to_sse_server(self, server_url: str):
        """Connect to an MCP server running with SSE transport."""
        streams_context = sse_client(url=server_url)
        streams = await self.exit_stack.enter_async_context(streams_context)
        session_context = ClientSession(*streams)
        self.session = await self.exit_stack.enter_async_context(
            session_context)

        # Initialize session
        await self.session.initialize()

    async def cleanup(self):
        """Properly clean up all registered async resources."""
        await self.exit_stack.aclose()


async def main():
    if len(sys.argv) < 2:
        print(
            "Usage: uv run client.py <URL of SSE MCP server (i.e. http://localhost:8080/sse)>"
        )
        sys.exit(1)

    client = MCPClient()
    try:
        await client.connect_to_sse_server(server_url=sys.argv[1])
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
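`MCPWorker` drives this client internally, but it can also be used standalone, for example against the weather server from this change (the URL and state argument below are assumptions):

```python
import asyncio

from tensorrt_llm.scaffolding.contrib.mcp.mcp_utils import MCPClient


async def demo() -> None:
    client = MCPClient()
    await client.connect_to_sse_server("http://0.0.0.0:8080/sse")
    try:
        # List the tools the server exposes, then call one of them.
        tools = await client.list_tools()
        print([tool.name for tool in tools.tools])
        result = await client.call_tool("get_alerts", {"state": "CA"})
        print(result.content[0].text)
    finally:
        await client.cleanup()


asyncio.run(demo())
```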


@ -0,0 +1,57 @@
import asyncio
import json
from typing import List

from tensorrt_llm.scaffolding import TaskStatus, Worker

from .mcp_task import MCPCallTask, MCPListTask
from .mcp_utils import MCPClient


class MCPWorker(Worker):

    def __init__(
        self,
        mcp_clients: List,
    ):
        self.mcp_clients = mcp_clients

    @classmethod
    async def init_with_urls(cls, urls):
        clients = []
        for url in urls:
            client = MCPClient()
            await client.connect_to_sse_server(server_url=url)
            clients.append(client)
        return cls(clients)

    async def call_handler(self, task: MCPCallTask) -> TaskStatus:
        for mcp_client in self.mcp_clients:
            response = await mcp_client.list_tools()
            for tool in response.tools:
                if task.tool_name not in tool.name:
                    continue
                print(f"\ncall handler {tool.name} and {task.tool_name}\n")
                tool_name = task.tool_name
                args = json.loads(task.args)
                response = await mcp_client.call_tool(tool_name, args)
                task.output_str = response.content[0].text
                return TaskStatus.SUCCESS
        # No server exposed the requested tool; report a worker error.
        return TaskStatus.WORKER_EXECEPTION

    async def list_handler(self, task: MCPListTask) -> TaskStatus:
        result_tools = []
        for mcp_client in self.mcp_clients:
            response = await mcp_client.list_tools()
            result_tools.extend(response.tools)
        task.result_tools = result_tools
        return TaskStatus.SUCCESS

    def shutdown(self):
        loop = asyncio.get_event_loop()
        for mcp_client in self.mcp_clients:
            if loop.is_running():
                asyncio.run_coroutine_threadsafe(mcp_client.cleanup(), loop)
            else:
                loop.run_until_complete(mcp_client.cleanup())

    task_handlers = {MCPListTask: list_handler, MCPCallTask: call_handler}