diff --git a/examples/scaffolding/contrib/mcp/README.md b/examples/scaffolding/contrib/mcp/README.md
new file mode 100644
index 0000000000..ca5fac0499
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/README.md
@@ -0,0 +1,50 @@
+# MCP Usage
+
+## Step 1: Run the Servers
+
+### Terminal 1: weather
+
+`cd weather`
+
+`pip install uv`
+
+`uv add "mcp[cli]" httpx openai`
+
+`uv pip install httpx mcp`
+
+`uv init --no-workspace`
+
+`uv run weather.py`
+
+### Terminal 2: e2b
+
+`cd e2b`
+
+`pip install uv`
+
+`uv add "mcp[cli]" httpx openai`
+
+`uv pip install e2b_code_interpreter mcp`
+
+`uv init --no-workspace`
+
+`uv run e2bserver.py`
+
+### Terminal 3: websearch
+
+`cd websearch`
+
+`pip install uv`
+
+`uv add "mcp[cli]" httpx openai`
+
+`uv pip install brave-search mcp starlette uvicorn`
+
+`uv init --no-workspace`
+
+`uv run websearch.py`
+
+## Step 2: Run the Test
+
+`python3 mcptest.py --API_KEY YOUR_API_KEY`
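+## Optional: Check That a Server Is Up
+
+Each server exposes its tools over SSE on a fixed port (weather: 8080, e2b: 8081, websearch: 8082). The minimal sketch below is not part of the example code; it assumes the default host and ports above, connects to one server, and prints the tools it advertises:
+
+```python
+import asyncio
+
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+
+async def check(url: str) -> None:
+    # Open the SSE connection, initialize the session, list the tools.
+    async with sse_client(url=url) as streams:
+        async with ClientSession(*streams) as session:
+            await session.initialize()
+            tools = await session.list_tools()
+            print(url, [tool.name for tool in tools.tools])
+
+
+asyncio.run(check("http://0.0.0.0:8080/sse"))
+```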
diff --git a/examples/scaffolding/contrib/mcp/e2b/.env b/examples/scaffolding/contrib/mcp/e2b/.env
new file mode 100644
index 0000000000..98f94af65c
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/e2b/.env
@@ -0,0 +1 @@
+E2B_API_KEY=$YOUR_API_KEY
diff --git a/examples/scaffolding/contrib/mcp/e2b/README.md b/examples/scaffolding/contrib/mcp/e2b/README.md
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/scaffolding/contrib/mcp/e2b/e2bserver.py b/examples/scaffolding/contrib/mcp/e2b/e2bserver.py
new file mode 100644
index 0000000000..1818619fca
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/e2b/e2bserver.py
@@ -0,0 +1,92 @@
+import logging
+
+import uvicorn
+from dotenv import load_dotenv
+from e2b_code_interpreter import Sandbox
+from mcp.server import Server
+from mcp.server.fastmcp import FastMCP
+from mcp.server.sse import SseServerTransport
+from pydantic import BaseModel
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.routing import Mount, Route
+
+# Initialize FastMCP server for sandbox tools (SSE)
+mcp = FastMCP("sandbox")
+
+# Load environment variables
+load_dotenv()
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("e2b-mcp-server")
+
+
+# Tool schema
+class ToolSchema(BaseModel):
+    code: str
+
+
+@mcp.tool()
+async def run_code(code: str) -> str:
+    """Run Python code in a secure E2B sandbox, using Jupyter Notebook syntax.
+
+    The response includes:
+    1. results: the function return value.
+    2. stdout: the standard output.
+    3. stderr: the standard error.
+
+    Args:
+        code: string in Jupyter Notebook syntax.
+    """
+    sbx = Sandbox()
+    execution = sbx.run_code(code)
+    logger.info(f"Execution: {execution}")
+
+    result = {
+        "results": execution.results,
+        "stdout": execution.logs.stdout,
+        "stderr": execution.logs.stderr,
+    }
+
+    return f"{result}"
+
+
+def create_starlette_app(mcp_server: Server,
+                         *,
+                         debug: bool = False) -> Starlette:
+    """Create a Starlette application that can serve the provided MCP server with SSE."""
+    sse = SseServerTransport("/messages/")
+
+    async def handle_sse(request: Request) -> None:
+        async with sse.connect_sse(
+                request.scope,
+                request.receive,
+                request._send,  # noqa: SLF001
+        ) as (read_stream, write_stream):
+            await mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp_server.create_initialization_options(),
+            )
+
+    return Starlette(
+        debug=debug,
+        routes=[
+            Route("/sse", endpoint=handle_sse),
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
+    )
+
+
+if __name__ == "__main__":
+    mcp_server = mcp._mcp_server  # noqa: WPS437
+
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
+    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
+    parser.add_argument('--port',
+                        type=int,
+                        default=8081,
+                        help='Port to listen on')
+    args = parser.parse_args()
+
+    # Bind SSE request handling to MCP server
+    starlette_app = create_starlette_app(mcp_server, debug=True)
+
+    uvicorn.run(starlette_app, host=args.host, port=args.port)
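+# A typical client-side call of the tool above, sketched with the same mcp
+# client API this example uses in mcp_utils.py (URL and code string are
+# illustrative):
+#
+#     async with sse_client(url="http://0.0.0.0:8081/sse") as streams:
+#         async with ClientSession(*streams) as session:
+#             await session.initialize()
+#             result = await session.call_tool("run_code",
+#                                              {"code": "print(1 + 1)"})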
diff --git a/examples/scaffolding/contrib/mcp/e2b/main.py b/examples/scaffolding/contrib/mcp/e2b/main.py
new file mode 100644
index 0000000000..53e31d164c
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/e2b/main.py
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from e2b!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scaffolding/contrib/mcp/e2b/pyproject.toml b/examples/scaffolding/contrib/mcp/e2b/pyproject.toml
new file mode 100644
index 0000000000..a90d116104
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/e2b/pyproject.toml
@@ -0,0 +1,7 @@
+[project]
+name = "e2b"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = []
diff --git a/examples/scaffolding/contrib/mcp/mcptest.py b/examples/scaffolding/contrib/mcp/mcptest.py
new file mode 100644
index 0000000000..ebfa29012d
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/mcptest.py
@@ -0,0 +1,77 @@
+import argparse
+import asyncio
+
+from openai import AsyncOpenAI
+
+from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
+from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController,
+                                              MCPWorker, chat_handler)
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--base_url',
+        type=str,
+        default="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    )
+    parser.add_argument(
+        '--model',
+        type=str,
+        default="qwen-plus-latest",
+    )
+    parser.add_argument('--API_KEY', type=str)
+    args = parser.parse_args()
+    return args
+
+
+async def main():
+    args = parse_arguments()
+    prompts = [
+        # "What's the weather like today in LA?"
+        # 'Solve the problem with running python code: What is the number of Fibonacci array 20th element? The array goes like 0,1,1,2,3...'
+        # 'Which game won TGA Best Action Game and Players Voice awards in 2024?'
+        'What was the score of the NBA playoffs game 7 between the Thunder and the Nuggets in 2025?'
+    ]
+    API_KEY = args.API_KEY
+    urls = [
+        "http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
+        "http://0.0.0.0:8082/sse"
+    ]
+    print(f"API_KEY {API_KEY}")
+    client = AsyncOpenAI(api_key=API_KEY, base_url=args.base_url)
+    qwen_worker = OpenaiWorker(client, args.model)
+    qwen_worker.register_task_handler(ChatTask, chat_handler)
+    mcp_worker = await MCPWorker.init_with_urls(urls)
+
+    prototype_controller = MCPController()
+    llm = ScaffoldingLlm(
+        prototype_controller,
+        {
+            MCPController.WorkerTag.GENERATION: qwen_worker,
+            MCPController.WorkerTag.MCP: mcp_worker
+        },
+    )
+
+    future = llm.generate_async(prompts[0])
+    result = await future.aresult()
+    print(f"\nresult is {result.output.output_str}\n")
+
+    print('main shutting down...')
+    llm.shutdown()
+    print('worker shutting down...')
+    qwen_worker.shutdown()
+    mcp_worker.shutdown()
+
+    print('main shut down done')
+    return
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
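+# What a run does end to end (see MCPController.process):
+# 1. An MCPListTask collects the tool schemas advertised by all three servers.
+# 2. A ChatTask sends the prompt plus those schemas to the OpenAI-compatible
+#    model.
+# 3. If the model answers with tool_calls, MCPCallTasks execute them on the
+#    server that provides the matching tool.
+# 4. A final ChatTask turns the tool output into the printed answer.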
diff --git a/examples/scaffolding/contrib/mcp/weather/pyproject.toml b/examples/scaffolding/contrib/mcp/weather/pyproject.toml
new file mode 100644
index 0000000000..a36f3b9a81
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/weather/pyproject.toml
@@ -0,0 +1,7 @@
+[project]
+name = "weather"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = []
diff --git a/examples/scaffolding/contrib/mcp/weather/weather.py b/examples/scaffolding/contrib/mcp/weather/weather.py
new file mode 100644
index 0000000000..fb542bbc00
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/weather/weather.py
@@ -0,0 +1,144 @@
+from typing import Any
+
+import httpx
+import uvicorn
+from mcp.server import Server
+from mcp.server.fastmcp import FastMCP
+from mcp.server.sse import SseServerTransport
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.routing import Mount, Route
+
+# Initialize FastMCP server for weather tools (SSE)
+mcp = FastMCP("weather")
+
+# Constants
+NWS_API_BASE = "https://api.weather.gov"
+USER_AGENT = "weather-app/1.0"
+
+
+async def make_nws_request(url: str) -> dict[str, Any] | None:
+    """Make a request to the NWS API with proper error handling."""
+    headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"}
+    async with httpx.AsyncClient() as client:
+        try:
+            response = await client.get(url, headers=headers, timeout=30.0)
+            response.raise_for_status()
+            return response.json()
+        except Exception:
+            return None
+
+
+def format_alert(feature: dict) -> str:
+    """Format an alert feature into a readable string."""
+    props = feature["properties"]
+    return f"""
+Event: {props.get('event', 'Unknown')}
+Area: {props.get('areaDesc', 'Unknown')}
+Severity: {props.get('severity', 'Unknown')}
+Description: {props.get('description', 'No description available')}
+Instructions: {props.get('instruction', 'No specific instructions provided')}
+"""
+
+
+@mcp.tool()
+async def get_alerts(state: str) -> str:
+    """Get weather alerts for a US state.
+
+    Args:
+        state: Two-letter US state code (e.g. CA, NY)
+    """
+    url = f"{NWS_API_BASE}/alerts/active/area/{state}"
+    data = await make_nws_request(url)
+
+    if not data or "features" not in data:
+        return "Unable to fetch alerts or no alerts found."
+
+    if not data["features"]:
+        return "No active alerts for this state."
+
+    alerts = [format_alert(feature) for feature in data["features"]]
+    return "\n---\n".join(alerts)
+
+
+@mcp.tool()
+async def get_forecast(latitude: float, longitude: float) -> str:
+    """Get weather forecast for a location.
+
+    Args:
+        latitude: Latitude of the location
+        longitude: Longitude of the location
+    """
+    # First get the forecast grid endpoint
+    points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
+    points_data = await make_nws_request(points_url)
+
+    if not points_data:
+        return "Unable to fetch forecast data for this location."
+
+    # Get the forecast URL from the points response
+    forecast_url = points_data["properties"]["forecast"]
+    forecast_data = await make_nws_request(forecast_url)
+
+    if not forecast_data:
+        return "Unable to fetch detailed forecast."
+
+    # Format the periods into a readable forecast
+    periods = forecast_data["properties"]["periods"]
+    forecasts = []
+    for period in periods[:5]:  # Only show next 5 periods
+        forecast = f"""
+{period['name']}:
+Temperature: {period['temperature']}°{period['temperatureUnit']}
+Wind: {period['windSpeed']} {period['windDirection']}
+Forecast: {period['detailedForecast']}
+"""
+        forecasts.append(forecast)
+
+    return "\n---\n".join(forecasts)
+
+
+def create_starlette_app(mcp_server: Server,
+                         *,
+                         debug: bool = False) -> Starlette:
+    """Create a Starlette application that can serve the provided MCP server with SSE."""
+    sse = SseServerTransport("/messages/")
+
+    async def handle_sse(request: Request) -> None:
+        async with sse.connect_sse(
+                request.scope,
+                request.receive,
+                request._send,  # noqa: SLF001
+        ) as (read_stream, write_stream):
+            await mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp_server.create_initialization_options(),
+            )
+
+    return Starlette(
+        debug=debug,
+        routes=[
+            Route("/sse", endpoint=handle_sse),
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
+    )
+
+
+if __name__ == "__main__":
+    mcp_server = mcp._mcp_server  # noqa: WPS437
+
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
+    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
+    parser.add_argument('--port',
+                        type=int,
+                        default=8080,
+                        help='Port to listen on')
+    args = parser.parse_args()
+
+    # Bind SSE request handling to MCP server
+    starlette_app = create_starlette_app(mcp_server, debug=True)
+
+    uvicorn.run(starlette_app, host=args.host, port=args.port)
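+# A typical client-side call of the alerts tool above, sketched with the same
+# mcp client API this example uses in mcp_utils.py (URL and state code are
+# illustrative):
+#
+#     async with sse_client(url="http://0.0.0.0:8080/sse") as streams:
+#         async with ClientSession(*streams) as session:
+#             await session.initialize()
+#             result = await session.call_tool("get_alerts", {"state": "CA"})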
diff --git a/examples/scaffolding/contrib/mcp/websearch/.env b/examples/scaffolding/contrib/mcp/websearch/.env
new file mode 100644
index 0000000000..e32a584138
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/websearch/.env
@@ -0,0 +1 @@
+BRAVE_API_KEY=$YOUR_API_KEY
diff --git a/examples/scaffolding/contrib/mcp/websearch/README.md b/examples/scaffolding/contrib/mcp/websearch/README.md
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/scaffolding/contrib/mcp/websearch/main.py b/examples/scaffolding/contrib/mcp/websearch/main.py
new file mode 100644
index 0000000000..9d579b76c3
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/websearch/main.py
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from websearch!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scaffolding/contrib/mcp/websearch/pyproject.toml b/examples/scaffolding/contrib/mcp/websearch/pyproject.toml
new file mode 100644
index 0000000000..9a228c47df
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/websearch/pyproject.toml
@@ -0,0 +1,11 @@
+[project]
+name = "websearch"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "httpx>=0.28.1",
+    "mcp[cli]>=1.9.0",
+    "openai>=1.79.0",
+]
diff --git a/examples/scaffolding/contrib/mcp/websearch/websearch.py b/examples/scaffolding/contrib/mcp/websearch/websearch.py
new file mode 100644
index 0000000000..1df2217455
--- /dev/null
+++ b/examples/scaffolding/contrib/mcp/websearch/websearch.py
@@ -0,0 +1,76 @@
+import os
+
+import uvicorn
+from brave import Brave
+from dotenv import load_dotenv
+from mcp.server import Server
+from mcp.server.fastmcp import FastMCP
+from mcp.server.sse import SseServerTransport
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.routing import Mount, Route
+
+# Initialize FastMCP server for websearch tools (SSE)
+mcp = FastMCP("websearch")
+
+# Load environment variables
+load_dotenv()
+
+
+@mcp.tool()
+async def websearch(query: str) -> str:
+    """Web search: fetch information from the internet.
+
+    Args:
+        query: string of what you want to search
+    """
+    BRAVE_API_KEY = os.getenv("BRAVE_API_KEY")
+    brave = Brave(BRAVE_API_KEY)
+    print(f"brave apikey {BRAVE_API_KEY}")
+    search_results = brave.search(q=query, raw=True)
+    return f"{search_results}"
+
+
+def create_starlette_app(mcp_server: Server,
+                         *,
+                         debug: bool = False) -> Starlette:
+    """Create a Starlette application that can serve the provided MCP server with SSE."""
+    sse = SseServerTransport("/messages/")
+
+    async def handle_sse(request: Request) -> None:
+        async with sse.connect_sse(
+                request.scope,
+                request.receive,
+                request._send,  # noqa: SLF001
+        ) as (read_stream, write_stream):
+            await mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp_server.create_initialization_options(),
+            )
+
+    return Starlette(
+        debug=debug,
+        routes=[
+            Route("/sse", endpoint=handle_sse),
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
+    )
+
+
+if __name__ == "__main__":
+    mcp_server = mcp._mcp_server  # noqa: WPS437
+
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Run MCP SSE-based server')
+    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
+    parser.add_argument('--port',
+                        type=int,
+                        default=8082,
+                        help='Port to listen on')
+    args = parser.parse_args()
+
+    # Bind SSE request handling to MCP server
+    starlette_app = create_starlette_app(mcp_server, debug=True)
+
+    uvicorn.run(starlette_app, host=args.host, port=args.port)
diff --git a/tensorrt_llm/scaffolding/contrib/__init__.py b/tensorrt_llm/scaffolding/contrib/__init__.py
index 789666b5dc..29ef484f41 100644
--- a/tensorrt_llm/scaffolding/contrib/__init__.py
+++ b/tensorrt_llm/scaffolding/contrib/__init__.py
@@ -2,6 +2,8 @@
 from tensorrt_llm.scaffolding import *  # noqa
 from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
 from .Dynasor import DynasorGenerationController
+from .mcp import (ChatTask, MCPCallTask, MCPController, MCPListTask, MCPWorker,
+                  chat_handler)
 
 __all__ = [
     # AsyncGeneration
@@ -9,4 +11,11 @@ __all__ = [
     "StreamGenerationTask",
     # Dynasor
     "DynasorGenerationController",
+    # MCP
+    "MCPController",
+    "MCPWorker",
+    "MCPCallTask",
+    "MCPListTask",
+    "ChatTask",
+    "chat_handler"
 ]
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/__init__.py b/tensorrt_llm/scaffolding/contrib/mcp/__init__.py
new file mode 100644
index 0000000000..d1d685c19e
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/__init__.py
@@ -0,0 +1,10 @@
+from .chat_handler import chat_handler
+from .chat_task import ChatTask
+from .mcp_controller import MCPController
+from .mcp_task import MCPCallTask, MCPListTask
+from .mcp_worker import MCPWorker
+
+__all__ = [
+    "MCPController", "MCPWorker", "MCPCallTask", "MCPListTask", "ChatTask",
+    "chat_handler"
+]
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py b/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py
new file mode 100644
index 0000000000..90f10e0808
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py
@@ -0,0 +1,54 @@
+from openai.types.chat import ChatCompletion
+
+from tensorrt_llm.executor import GenerationExecutor
+from tensorrt_llm.scaffolding import TaskStatus
+from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask
+
+ExecutorCls = GenerationExecutor
+
+
+# Helper function:
+# add the first non-None value from candidate_values to params under key.
+def add_param_if_not_none(params, key, candidate_values):
+    for value in candidate_values:
+        if value is not None:
+            params[key] = value
+            return
+
+
+def combine_params_with_chat_task(worker, params: dict, task: ChatTask):
+    params["messages"] = task.messages
+
+    add_param_if_not_none(params, "max_tokens",
+                          [task.max_tokens, worker.max_tokens])
+    add_param_if_not_none(params, "temperature",
+                          [task.temperature, worker.temperature])
+    add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p])
+
+    add_param_if_not_none(params, "tools", [task.tools])
+
+
+def fill_chat_task_with_response(task: ChatTask, response: ChatCompletion):
+    task.output_str = response.choices[0].message.content
+    task.finish_reason = response.choices[0].finish_reason
+    task.tool_calls = response.choices[0].message.tool_calls
+    task.logprobs = response.choices[0].logprobs
+
+
+async def chat_handler(worker, task: ChatTask) -> TaskStatus:
+    params = {}
+    # Set required parameters
+    params["model"] = worker.model
+
+    combine_params_with_chat_task(worker, params, task)
+
+    # Make the API call
+    try:
+        response = await worker.async_client.chat.completions.create(**params)
+        fill_chat_task_with_response(task, response)
+        return TaskStatus.SUCCESS
+
+    except Exception as e:
+        # Handle errors
+        print('OpenAI chat client got an exception: ' + str(e))
+        return TaskStatus.WORKER_EXECEPTION
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/chat_task.py b/tensorrt_llm/scaffolding/contrib/mcp/chat_task.py
new file mode 100644
index 0000000000..1f4d72d5eb
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/chat_task.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from tensorrt_llm.scaffolding import GenerationTask
+
+
+@dataclass
+class ChatTask(GenerationTask):
+    messages: Optional[list] = None
+    tools: Optional[list] = None
+    finish_reason: Optional[str] = None
+    tool_calls: Optional[list] = None
+
+    @staticmethod
+    def create_from_prompt(messages: list, prompt: str, tools) -> "ChatTask":
+        """Append the user prompt to messages and build a ChatTask from them."""
+        task = ChatTask()
+        messages.append({"role": "user", "content": prompt})
+        task.messages = messages
+        task.tools = tools
+        return task
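+# Each entry in `tools` follows the OpenAI function-calling schema, which is
+# what MCPController builds from the MCP tool list, e.g. (illustrative):
+#
+#     {
+#         "type": "function",
+#         "function": {
+#             "name": "get_forecast",
+#             "description": "Get weather forecast for a location.",
+#             "parameters": {...},  # JSON schema of the tool's arguments
+#         },
+#     }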
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py b/tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py
new file mode 100644
index 0000000000..c27ab35e53
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/mcp_controller.py
@@ -0,0 +1,78 @@
+import copy
+from enum import Enum
+from typing import List
+
+from tensorrt_llm.scaffolding import Controller, Task
+
+from .chat_task import ChatTask
+from .mcp_task import MCPCallTask, MCPListTask
+
+
+class MCPController(Controller):
+
+    class WorkerTag(Enum):
+        GENERATION = "generation"
+        MCP = "mcp"
+
+    def __init__(self, custom_sampling_params: dict = None):
+        super().__init__()
+        self.custom_sampling_params = copy.deepcopy(
+            custom_sampling_params) if custom_sampling_params else None
+
+    def process(self, tasks: List[Task], **kwargs):
+        """List the MCP tools, let the model decide whether to call one,
+        execute the tool calls, then chat once more for the final answer."""
+        list_task = MCPListTask.create_mcptask()
+        list_task.worker_tag = MCPController.WorkerTag.MCP
+        yield [list_task]
+        available_tools = [{
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.inputSchema
+            }
+        } for tool in list_task.result_tools]
+
+        print(f"\navailable_tools {available_tools}\n")
+        assert (len(tasks) == 1)
+        system_message = (
+            "You are a helpful assistant with access to tools:\n\n"
+            "After receiving a tool's response:\n"
+            "1. Transform the raw data into a natural, conversational response\n"
+            "2. Keep responses concise but informative\n"
+            "3. Focus on the most relevant information\n"
+            "4. Use appropriate context from the user's question\n"
+            "5. Avoid simply repeating the raw data\n\n"
+            "Please use only the tools that are explicitly defined above.")
+        messages = [{"role": "system", "content": system_message}]
+        chattask = ChatTask.create_from_prompt(messages, tasks[0].input_str,
+                                               available_tools)
+        result_task = tasks[0]
+        chattask.worker_tag = self.WorkerTag.GENERATION
+        if self.custom_sampling_params:
+            for key, value in self.custom_sampling_params.items():
+                if hasattr(tasks[0], key) and getattr(tasks[0], key) is None:
+                    setattr(tasks[0], key, value)
+        yield [chattask]
+        if chattask.finish_reason != 'tool_calls':
+            result_task.output_str = chattask.output_str
+            return
+        tool_calls = chattask.tool_calls
+        mcp_call_tasks = [
+            MCPCallTask.create_mcptask(tool_call.function.name,
+                                       tool_call.function.arguments)
+            for tool_call in tool_calls
+        ]
+        for task in mcp_call_tasks:
+            task.worker_tag = MCPController.WorkerTag.MCP
+        print(f"\nmcp_call_tasks is {mcp_call_tasks}\n")
+        yield mcp_call_tasks
+        # Only the first tool result is fed back to the model for now.
+        mcp_result = mcp_call_tasks[0].output_str
+        print(f"\nmcp_result is {mcp_result}\n")
+        messages.append({"role": "assistant", "content": chattask.output_str})
+        finalchattask = ChatTask.create_from_prompt(messages, mcp_result,
+                                                    available_tools)
+        finalchattask.worker_tag = self.WorkerTag.GENERATION
+        yield [finalchattask]
+        result_task.output_str = finalchattask.output_str
+        return
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/mcp_task.py b/tensorrt_llm/scaffolding/contrib/mcp/mcp_task.py
new file mode 100644
index 0000000000..6fe50e0350
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/mcp_task.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+from tensorrt_llm.scaffolding.task import Task
+
+
+@dataclass
+class MCPCallTask(Task):
+    # MCP inputs; args is the JSON-encoded argument string from the model
+    tool_name: Optional[str] = field(default=None)
+    args: Optional[str] = field(default=None)
+    # retrying control
+    retry: Optional[int] = field(default=1)
+    delay: Optional[float] = field(default=10)
+
+    worker_tag: Union[str, "Controller.WorkerTag"] = None
+
+    # result field
+    result_str: Optional[str] = None
+
+    @staticmethod
+    def create_mcptask(tool_name: str,
+                       args: str,
+                       retry: int = 1,
+                       delay: float = 1) -> "MCPCallTask":
+        task = MCPCallTask()
+        task.tool_name = tool_name
+        task.args = args
+        task.retry = retry
+        task.delay = delay
+        return task
+
+
+@dataclass
+class MCPListTask(Task):
+    worker_tag: Union[str, "Controller.WorkerTag"] = None
+
+    # result fields
+    result_str: Optional[str] = None
+    result_tools: Optional[list] = None
+
+    @staticmethod
+    def create_mcptask() -> "MCPListTask":
+        task = MCPListTask()
+        return task
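+# Typical construction, mirroring how MCPController turns an OpenAI tool call
+# into a task. Note that `args` is the JSON string the model produced
+# (tool_call.function.arguments); the worker json.loads it before the call.
+# The tool name and query below are illustrative:
+#
+#     task = MCPCallTask.create_mcptask(
+#         "websearch", '{"query": "TGA 2024 Best Action Game"}')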
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py b/tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py
new file mode 100644
index 0000000000..10d332639b
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/mcp_utils.py
@@ -0,0 +1,60 @@
+import asyncio
+import sys
+from contextlib import AsyncExitStack
+from typing import Optional
+
+from dotenv import load_dotenv
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+load_dotenv()  # load environment variables from .env
+
+
+class MCPClient:
+
+    def __init__(self):
+        # Initialize session and client objects
+        self.session: Optional[ClientSession] = None
+        self.exit_stack = AsyncExitStack()
+
+    async def list_tools(self):
+        response = await self.session.list_tools()
+        return response
+
+    async def call_tool(self, tool_name, tool_args):
+        result = await self.session.call_tool(tool_name, tool_args)
+        return result
+
+    async def connect_to_sse_server(self, server_url: str):
+        """Connect to an MCP server running with SSE transport."""
+        streams_context = sse_client(url=server_url)
+        streams = await self.exit_stack.enter_async_context(streams_context)
+
+        session_context = ClientSession(*streams)
+        self.session = await self.exit_stack.enter_async_context(
+            session_context)
+
+        # Initialize session
+        await self.session.initialize()
+
+    async def cleanup(self):
+        """Properly clean up all registered async resources."""
+        await self.exit_stack.aclose()
+
+
+async def main():
+    if len(sys.argv) < 2:
+        print("Usage: uv run client.py <server_url>")
+        sys.exit(1)
+
+    client = MCPClient()
+    try:
+        await client.connect_to_sse_server(server_url=sys.argv[1])
+    finally:
+        await client.cleanup()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py b/tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py
new file mode 100644
index 0000000000..85bb628134
--- /dev/null
+++ b/tensorrt_llm/scaffolding/contrib/mcp/mcp_worker.py
@@ -0,0 +1,57 @@
+import asyncio
+import json
+from typing import List
+
+from tensorrt_llm.scaffolding import TaskStatus, Worker
+
+from .mcp_task import MCPCallTask, MCPListTask
+from .mcp_utils import MCPClient
+
+
+class MCPWorker(Worker):
+
+    def __init__(
+        self,
+        mcp_clients: List,
+    ):
+        self.mcp_clients = mcp_clients
+
+    @classmethod
+    async def init_with_urls(cls, urls):
+        clients = []
+        for url in urls:
+            client = MCPClient()
+            await client.connect_to_sse_server(server_url=url)
+            clients.append(client)
+        return cls(clients)
+
+    async def call_handler(self, task: MCPCallTask) -> TaskStatus:
+        # Find the client that serves the requested tool and call it.
+        for mcp_client in self.mcp_clients:
+            response = await mcp_client.list_tools()
+            for tool in response.tools:
+                if tool.name != task.tool_name:
+                    continue
+                print(f"\ncall handler {tool.name} and {task.tool_name}\n")
+                tool_name = task.tool_name
+                args = json.loads(task.args)
+                response = await mcp_client.call_tool(tool_name, args)
+                task.output_str = response.content[0].text
+                return TaskStatus.SUCCESS
+        # No connected server advertised the requested tool.
+        return TaskStatus.WORKER_EXECEPTION
+
+    async def list_handler(self, task: MCPListTask) -> TaskStatus:
+        result_tools = []
+        for mcp_client in self.mcp_clients:
+            response = await mcp_client.list_tools()
+            result_tools.extend(response.tools)
+        task.result_tools = result_tools
+        return TaskStatus.SUCCESS
+
+    def shutdown(self):
+        loop = asyncio.get_event_loop()
+        for mcp_client in self.mcp_clients:
+            if loop.is_running():
+                asyncio.run_coroutine_threadsafe(mcp_client.cleanup(), loop)
+            else:
+                loop.run_until_complete(mcp_client.cleanup())
+
+    task_handlers = {MCPListTask: list_handler, MCPCallTask: call_handler}