[None][feat] Deep Research Implemented with Scaffolding (#8452)

Signed-off-by: Yi Sun <yisun0618@gmail.com>

parent 6bbb43f2b9
commit cc12d33393
@@ -0,0 +1,10 @@
# Tavily MCP

This is an MCP server for the Tavily API. It is used to search the web for information.

## Usage

```
export TAVILY_API_KEY=<your_api_key>
uv run travily.py
```
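For reference, here is a minimal client sketch for exercising this server once it is running. It is a hedged example, not part of the commit: it assumes the `mcp` Python SDK's SSE client helpers (`sse_client`, `ClientSession`) and the server's default port 8082, and the query string is only illustrative.

```
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def main() -> None:
    # Connect to the SSE endpoint exposed by travily.py (default port 8082).
    async with sse_client("http://localhost:8082/sse") as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            # Invoke the tavily_search tool registered by the server.
            result = await session.call_tool("tavily_search", {"query": "TensorRT-LLM"})
            print(result.content)


if __name__ == "__main__":
    asyncio.run(main())
```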
@@ -0,0 +1,12 @@
[project]
name = "TavilyMCP"
version = "0.1.0"
description = "An MCP server for searching information"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "httpx>=0.28.1",
    "mcp[cli]>=1.9.0",
    "openai>=1.79.0",
    "tavily-python",
]
@@ -0,0 +1,72 @@
import logging
import os

import uvicorn
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from tavily import TavilyClient

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Initialize FastMCP server for Tavily search tools (SSE)
mcp = FastMCP("tavily_search")


@mcp.tool()
async def tavily_search(query: str) -> str:
    """Search the web for the query via the Tavily API and return title/content pairs."""
    TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

    client = TavilyClient(TAVILY_API_KEY)
    response = client.search(query=query)

    search_result = ""
    for result in response["results"]:
        search_result += f"{result['title']}: {result['content']}\n"

    return search_result


def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided mcp server with SSE."""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
            request.scope,
            request.receive,
            request._send,  # noqa: SLF001
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )


if __name__ == "__main__":
    mcp_server = mcp._mcp_server  # noqa: WPS437

    import argparse

    parser = argparse.ArgumentParser(description="Run MCP SSE-based server")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8082, help="Port to listen on")
    args = parser.parse_args()

    # Bind SSE request handling to MCP server
    starlette_app = create_starlette_app(mcp_server, debug=True)

    uvicorn.run(starlette_app, host=args.host, port=args.port)
@@ -0,0 +1,64 @@
import argparse
import asyncio

from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib.DeepResearch import Researcher, Supervisor
from tensorrt_llm.scaffolding.contrib.mcp import ChatTask, MCPController, MCPWorker, chat_handler


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_research_iter", type=int, default=10)
    parser.add_argument("--max_concurrent_research_units", type=int, default=10)
    parser.add_argument("--openai_api_key", type=str, required=True)
    parser.add_argument("--base_url", type=str, default="http://localhost:8000/v1")
    parser.add_argument("--model", type=str, default="gpt-oss-20b")
    return parser.parse_args()


async def main():
    args = parse_arguments()
    client = AsyncOpenAI(api_key=args.openai_api_key, base_url=args.base_url)

    generation_worker = OpenaiWorker(client, args.model)
    generation_worker.register_task_handler(ChatTask, chat_handler)

    mcp_worker = await MCPWorker.init_with_urls(["http://0.0.0.0:8082/sse"])

    supervisor = Supervisor(
        max_research_iter=args.max_research_iter,
        max_concurrent_research_units=args.max_concurrent_research_units,
    )

    llm = ScaffoldingLlm(
        prototype_controller=supervisor,
        workers={
            Supervisor.WorkerTag.GENERATION: generation_worker,
            Researcher.WorkerTag.GENERATION: generation_worker,
            MCPController.WorkerTag.MCP: mcp_worker,
        },
    )

    prompt = """
From 2020 to 2050, how many elderly people will there be in Japan? What is their consumption \
potential across various aspects such as clothing, food, housing, and transportation? \
Based on population projections, elderly consumer willingness, and potential changes in their \
consumption habits, please produce a market size analysis report for the elderly demographic.
"""

    future = llm.generate_async(prompt)
    result = await future.aresult()

    print(result.outputs[0].text)

    llm.shutdown()
    generation_worker.shutdown()
    mcp_worker.shutdown()

    return


if __name__ == "__main__":
    asyncio.run(main())
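Note: this example assumes the Tavily MCP server above is already running on port 8082 (matching the URL passed to `MCPWorker.init_with_urls`) and that an OpenAI-compatible endpoint, e.g. one started with `trtllm-serve`, is serving the target model at `--base_url`.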
@@ -0,0 +1,4 @@
from .researcher import Researcher
from .supervisor import Supervisor

__all__ = ["Supervisor", "Researcher"]
274  tensorrt_llm/scaffolding/contrib/DeepResearch/prompts.py  Normal file
@@ -0,0 +1,274 @@
# ruff: noqa: E501

generate_research_brief_prompt = """You will be given a set of messages that have been exchanged so far between yourself and the user.
Your job is to translate these messages into a more detailed and concrete research question that will be used to guide the research.

The messages that have been exchanged so far between yourself and the user are:
<Messages>
{messages}
</Messages>

Today's date is {date}.

You will return a single research question that will be used to guide the research.

Guidelines:
1. Maximize Specificity and Detail
- Include all known user preferences and explicitly list key attributes or dimensions to consider.
- It is important that all details from the user are included in the instructions.

2. Fill in Unstated But Necessary Dimensions as Open-Ended
- If certain attributes are essential for a meaningful output but the user has not provided them, explicitly state that they are open-ended or default to no specific constraint.

3. Avoid Unwarranted Assumptions
- If the user has not provided a particular detail, do not invent one.
- Instead, state the lack of specification and guide the researcher to treat it as flexible or accept all possible options.

4. Use the First Person
- Phrase the request from the perspective of the user.

5. Sources
- If specific sources should be prioritized, specify them in the research question.
- For product and travel research, prefer linking directly to official or primary websites (e.g., official brand sites, manufacturer pages, or reputable e-commerce platforms like Amazon for user reviews) rather than aggregator sites or SEO-heavy blogs.
- For academic or scientific queries, prefer linking directly to the original paper or official journal publication rather than survey papers or secondary summaries.
- For people, try linking directly to their LinkedIn profile, or their personal website if they have one.
- If the query is in a specific language, prioritize sources published in that language.
"""


supervisor_system_prompt = """You are a research supervisor. Your job is to conduct research by calling the "conduct_research" tool. For context, today's date is {date}.

<Task>
Your focus is to call the "conduct_research" tool to conduct research against the overall research question passed in by the user.
When you are completely satisfied with the research findings returned from the tool calls, then you should call the "complete_research" tool to indicate that you are done with your research.
</Task>

<Available Tools>
You have access to three main tools:
1. **conduct_research**: Delegate research tasks to specialized sub-agents
2. **complete_research**: Indicate that research is complete
3. **think_tool**: For reflection and strategic planning during research

**CRITICAL: Use think_tool before calling conduct_research to plan your approach, and after each conduct_research to assess progress. Do not call think_tool with any other tools in parallel.**

A typical workflow is:

1. Use think_tool to plan the sub-tasks.
2. For each of the planned sub-tasks, call conduct_research to conduct the research, and then use think_tool to assess the progress.
3. Call complete_research to indicate that you are done with your research.

You should use think_tool to plan the sub-tasks only once. Do not call think_tool multiple times in a row; that is unnecessary and will waste resources.

</Available Tools>

<Instructions>
Think like a research manager with limited time and resources. Follow these steps:

1. **Read the question carefully** - What specific information does the user need?
2. **Decide how to delegate the research** - Carefully consider the question and decide how to delegate the research. Are there multiple independent directions that can be explored simultaneously?
3. **After each call to conduct_research, pause and assess** - Do I have enough to answer? What's still missing?
</Instructions>

<Hard Limits>
**Task Delegation Budgets** (Prevent excessive delegation):
- **Bias towards single agent** - Use a single agent for simplicity unless the user request has a clear opportunity for parallelization
- **Stop when you can answer confidently** - Don't keep delegating research for perfection
- **Limit tool calls** - Always stop after {max_researcher_iterations} tool calls to conduct_research and think_tool if you cannot find the right sources

**Maximum {max_concurrent_research_units} parallel agents per iteration**
</Hard Limits>

<Show Your Thinking>
Before each conduct_research tool call, use think_tool to plan your approach:
- Can the task be broken down into smaller sub-tasks?

After each conduct_research tool call, use think_tool to analyze the results:
- What key information did I find?
- What's missing?
- Do I have enough to answer the question comprehensively?
- Should I delegate more research or call complete_research?
</Show Your Thinking>

<Scaling Rules>
**Simple fact-finding, lists, and rankings** can use a single sub-agent:
- *Example*: List the top 10 coffee shops in San Francisco → Use 1 sub-agent

**Comparisons presented in the user request** can use a sub-agent for each element of the comparison:
- *Example*: Compare OpenAI vs. Anthropic vs. DeepMind approaches to AI safety → Use 3 sub-agents
- Delegate clear, distinct, non-overlapping subtopics

**Important Reminders:**
- Each conduct_research call spawns a dedicated research agent for that specific topic
- A separate agent will write the final report - you just need to gather information
- When calling conduct_research, provide complete standalone instructions - sub-agents can't see other agents' work
- Do NOT use acronyms or abbreviations in your research questions, be very clear and specific
</Scaling Rules>"""


research_system_prompt = """You are a research assistant conducting research on the user's input topic. For context, today's date is {date}.

<Task>
Your job is to use tools to gather information about the user's input topic.
You can use any of the tools provided to you to find resources that can help answer the research question. You can call these tools in series or in parallel; your research is conducted in a tool-calling loop.
</Task>

<Available Tools>
You have access to two main tools:
1. **tavily_search**: For conducting web searches to gather information
2. **reflection**: For reflection and strategic planning during research

**CRITICAL: Use reflection after each search to reflect on results and plan next steps. Do not call reflection in parallel with tavily_search or any other tools; it should be used only to reflect on the results of the search.**
</Available Tools>

<Instructions>
Think like a human researcher with limited time. Follow these steps:

1. **Read the question carefully** - What specific information does the user need?
2. **Start with broader searches** - Use broad, comprehensive queries first
3. **After each search, pause and assess** - Do I have enough to answer? What's still missing?
4. **Execute narrower searches as you gather information** - Fill in the gaps
5. **Stop when you can answer confidently** - Don't keep searching for perfection
</Instructions>

<Hard Limits>
**Tool Call Budgets** (Prevent excessive searching):
- **Simple queries**: Use 2-3 search tool calls maximum
- **Complex queries**: Use up to 5 search tool calls maximum
- **Always stop**: After 5 search tool calls if you cannot find the right sources

**Stop Immediately When**:
- You can answer the user's question comprehensively
- You have 3+ relevant examples/sources for the question
- Your last 2 searches returned similar information
</Hard Limits>

<Show Your Thinking>
After each search tool call, use reflection to analyze the results:
- What key information did I find?
- What's missing?
- Do I have enough to answer the question comprehensively?
- Should I search more or provide my answer?
</Show Your Thinking>
"""


compress_system_prompt = """You are a research assistant that has conducted research on a topic by calling several tools and web searches. Your job is now to clean up the findings, but preserve all of the relevant statements and information that the researcher has gathered. For context, today's date is {date}.

<Task>
You need to clean up information gathered from tool calls and web searches in the existing messages.
All relevant information should be repeated and rewritten verbatim, but in a cleaner format.
The purpose of this step is just to remove any obviously irrelevant or duplicative information.
For example, if three sources all say "X", you could say "These three sources all stated X".
Only these fully comprehensive cleaned findings are going to be returned to the user, so it's crucial that you don't lose any information from the raw messages.
</Task>

<Guidelines>
1. Your output findings should be fully comprehensive and include ALL of the information and sources that the researcher has gathered from tool calls and web searches. It is expected that you repeat key information verbatim.
2. This report can be as long as necessary to return ALL of the information that the researcher has gathered.
3. In your report, you should return inline citations for each source that the researcher found.
4. You should include a "Sources" section at the end of the report that lists all of the sources the researcher found with corresponding citations, cited against statements in the report.
5. Make sure to include ALL of the sources that the researcher gathered in the report, and how they were used to answer the question!
6. It's really important not to lose any sources. A later LLM will be used to merge this report with others, so having all of the sources is critical.
</Guidelines>

<Output Format>
The report should be structured like this:
**List of Queries and Tool Calls Made**
**Fully Comprehensive Findings**
**List of All Relevant Sources (with citations in the report)**
</Output Format>

<Citation Rules>
- Assign each unique URL a single citation number in your text
- End with ### Sources that lists each source with corresponding numbers
- IMPORTANT: Number sources sequentially without gaps (1,2,3,4...) in the final list regardless of which sources you choose
- Example format:
[1] Source Title: URL
[2] Source Title: URL
</Citation Rules>

Critical Reminder: It is extremely important that any information that is even remotely relevant to the user's research topic is preserved verbatim (e.g. don't rewrite it, don't summarize it, don't paraphrase it).
"""


compress_research_simple_human_message = """All above messages are about research conducted by an AI Researcher. Please clean up these findings.

DO NOT summarize the information. I want the raw information returned, just in a cleaner format. Make sure all relevant information is preserved - you can rewrite findings verbatim."""


final_report_generation_prompt = """Based on all the research conducted, create a comprehensive, well-structured answer to the overall research brief:
<Research Brief>
{research_brief}
</Research Brief>

For more context, here are all of the messages so far. Focus on the research brief above, but consider these messages as well for more context.
<Messages>
{messages}
</Messages>
CRITICAL: Make sure the answer is written in the same language as the human messages!
For example, if the user's messages are in English, then MAKE SURE you write your response in English. If the user's messages are in Chinese, then MAKE SURE you write your entire response in Chinese.
This is critical. The user will only understand the answer if it is written in the same language as their input message.

Today's date is {date}.

Here are the findings from the research that you conducted:
<Findings>
{findings}
</Findings>

Please create a detailed answer to the overall research brief that:
1. Is well-organized with proper headings (# for title, ## for sections, ### for subsections)
2. Includes specific facts and insights from the research
3. References relevant sources using [Title](URL) format
4. Provides a balanced, thorough analysis. Be as comprehensive as possible, and include all information that is relevant to the overall research question. People are using you for deep research and will expect detailed, comprehensive answers.
5. Includes a "Sources" section at the end with all referenced links

You can structure your report in a number of different ways. Here are some examples:

To answer a question that asks you to compare two things, you might structure your report like this:
1/ intro
2/ overview of topic A
3/ overview of topic B
4/ comparison between A and B
5/ conclusion

To answer a question that asks you to return a list of things, you might only need a single section which is the entire list.
1/ list of things or table of things
Or, you could choose to make each item in the list a separate section in the report. When asked for lists, you don't need an introduction or conclusion.
1/ item 1
2/ item 2
3/ item 3

To answer a question that asks you to summarize a topic, give a report, or give an overview, you might structure your report like this:
1/ overview of topic
2/ concept 1
3/ concept 2
4/ concept 3
5/ conclusion

If you think you can answer the question with a single section, you can do that too!
1/ answer

REMEMBER: Section is a VERY fluid and loose concept. You can structure your report however you think is best, including in ways that are not listed above!
Make sure that your sections are cohesive, and make sense for the reader.

For each section of the report, do the following:
- Use simple, clear language
- Use ## for section title (Markdown format) for each section of the report
- Do NOT ever refer to yourself as the writer of the report. This should be a professional report without any self-referential language.
- Do not say what you are doing in the report. Just write the report without any commentary from yourself.
- Each section should be as long as necessary to deeply answer the question with the information you have gathered. It is expected that sections will be fairly long and verbose. You are writing a deep research report, and users will expect a thorough answer.
- Use bullet points to list out information when appropriate, but by default, write in paragraph form.

REMEMBER:
The brief and research may be in English, but you need to translate this information to the right language when writing the final answer.
Make sure the final answer report is in the SAME language as the human messages in the message history.

Format the report in clear markdown with proper structure and include source references where appropriate.

<Citation Rules>
- Assign each unique URL a single citation number in your text
- End with ### Sources that lists each source with corresponding numbers
- IMPORTANT: Number sources sequentially without gaps (1,2,3,4...) in the final list regardless of which sources you choose
- Each source should be a separate line item in a list, so that in markdown it is rendered as a list.
- Example format:
[1] Source Title: URL
[2] Source Title: URL
- Citations are extremely important. Make sure to include these, and pay a lot of attention to getting these right. Users will often use these citations to look into more information.
</Citation Rules>
"""
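For a quick sense of how these templates are consumed, here is a hedged sketch of the formatting call that supervisor.py (below) performs; the literal date and limits are placeholders standing in for `get_today_str()` and the Supervisor's configured values.

```
from tensorrt_llm.scaffolding.contrib.DeepResearch.prompts import supervisor_system_prompt

# Fill the supervisor template the same way Supervisor.process does.
system_text = supervisor_system_prompt.format(
    date="Mon Jan 15, 2024",           # normally get_today_str()
    max_researcher_iterations=10,      # Supervisor.max_research_iter
    max_concurrent_research_units=10,  # Supervisor.max_concurrent_research_units
)
print(system_text[:120])
```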
125  tensorrt_llm/scaffolding/contrib/DeepResearch/researcher.py  Normal file
@@ -0,0 +1,125 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding import Controller, Task
from tensorrt_llm.scaffolding.contrib.mcp import ChatTask, MCPCallTask, MCPController

from .prompts import (
    compress_research_simple_human_message,
    compress_system_prompt,
    research_system_prompt,
)
from .utils import AssistantMessage, SystemMessage, UserMessage, get_today_str

LOGGER = logging.getLogger()


@dataclass
class ResearchTask(Task):
    research_topic: str = field(default=None)
    research_result: str = field(default=None)

    @staticmethod
    def from_topic(topic: str) -> "ResearchTask":
        task = ResearchTask()
        task.research_topic = topic
        task.research_result = ""
        return task


class Researcher(Controller):
    class WorkerTag(Enum):
        GENERATION = "generation"

    def __init__(self, max_tools_iter: int = 3, max_compress_iter: int = 3):
        super().__init__()
        # TODO: Add more tools (e.g., MCP tools) beyond search.
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "tavily_search",
                    "description": "For conducting web searches to gather information",
                    "parameters": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "reflection",
                    "description": "For reflection and strategic planning during research",
                    "parameters": {
                        "type": "object",
                        "properties": {"reflection": {"type": "string"}},
                    },
                },
            },
        ]

        self.max_tools_iter = max_tools_iter
        self.max_compress_iter = max_compress_iter

    def process(self, research_tasks: List[ResearchTask], **kwargs):
        for research_task in research_tasks:
            research_prompt_messages = [
                SystemMessage(research_system_prompt.format(date=get_today_str())).to_dict(),
                UserMessage(research_task.research_topic).to_dict(),
            ]

            research_tools_messages = []
            chat_with_tools_task = ChatTask.from_messages(
                research_prompt_messages + research_tools_messages, self.tools
            )
            chat_with_tools_task.worker_tag = Researcher.WorkerTag.GENERATION

            for _ in range(self.max_tools_iter):
                yield [chat_with_tools_task]

                if chat_with_tools_task.finish_reason != "tool_calls":
                    break

                if chat_with_tools_task.output_str is not None:
                    research_tools_messages.append(
                        AssistantMessage(chat_with_tools_task.output_str).to_dict()
                    )

                mcp_call_tasks = [
                    MCPCallTask.create_mcptask(
                        tool_call.function.name, tool_call.function.arguments
                    )
                    for tool_call in chat_with_tools_task.tool_calls
                ]

                for mcp_call_task in mcp_call_tasks:
                    mcp_call_task.worker_tag = MCPController.WorkerTag.MCP

                yield mcp_call_tasks

                for mcp_call_task in mcp_call_tasks:
                    research_tools_messages.append(UserMessage(mcp_call_task.output_str).to_dict())

                chat_with_tools_task = ChatTask.from_messages(
                    research_prompt_messages + research_tools_messages, self.tools
                )
                chat_with_tools_task.worker_tag = Researcher.WorkerTag.GENERATION

            compress_prompt_messages = [
                SystemMessage(compress_system_prompt.format(date=get_today_str())).to_dict()
            ]

            compress_messages = research_tools_messages + [
                UserMessage(compress_research_simple_human_message).to_dict()
            ]
            compress_task = ChatTask.from_messages(compress_prompt_messages + compress_messages)
            compress_task.worker_tag = Researcher.WorkerTag.GENERATION

            for _ in range(self.max_compress_iter):
                yield [compress_task]
                research_task.research_result = compress_task.output_str
                if compress_task.finish_reason == "finish":
                    break
        return
200  tensorrt_llm/scaffolding/contrib/DeepResearch/supervisor.py  Normal file
@@ -0,0 +1,200 @@
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding.contrib.mcp import ChatTask
from tensorrt_llm.scaffolding.controller import Controller
from tensorrt_llm.scaffolding.task import Task

from .prompts import (
    final_report_generation_prompt,
    generate_research_brief_prompt,
    supervisor_system_prompt,
)
from .researcher import Researcher, ResearchTask
from .utils import AssistantMessage, SystemMessage, UserMessage, get_today_str


@dataclass
class SupervisorTask(Task):
    user_prompt: str = field(default=None)
    research_brief: str = field(default=None)
    final_report: str = field(default=None)

    @staticmethod
    def create_from_prompt(prompt: str) -> "SupervisorTask":
        task = SupervisorTask()
        task.user_prompt = prompt
        task.research_brief = None
        task.final_report = None
        return task


class Supervisor(Controller):
    class WorkerTag(Enum):
        GENERATION = "generation"

    def __init__(self, max_research_iter: int = 3, max_concurrent_research_units: int = 3):
        super().__init__()
        self.max_research_iter = max_research_iter
        self.max_concurrent_research_units = max_concurrent_research_units

        # TODO: Definition of researcher tools subject to certain specifications.
        # TODO: Add more tools (e.g., MCP tools) beyond search.
        self.researcher_tools = [
            {
                "type": "function",
                "function": {
                    "name": "conduct_research",
                    "description": "Conduct research on a given topic",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "research_topic": {
                                "type": "string",
                                "description": "The topic of the research",
                            }
                        },
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "complete_research",
                    "description": "Complete the research",
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "think_tool",
                    "description": "Think about the research",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "think": {
                                "type": "string",
                                "description": "The reflection of the research",
                            }
                        },
                    },
                },
            },
        ]

        self.researcher_controller = Researcher()

    def process(self, tasks: List[Task], **kwargs):
        supervisor_task = tasks[0]

        # For now, user messages only contain the user's prompt. Later, the user's
        # interactions with the supervisor (e.g., clarifying the research question)
        # can be added.
        user_messages = [UserMessage(content=supervisor_task.input_str).to_dict()]

        # Generate the research brief by wrapping the user's original prompt with
        # the system prompt for generating a research brief.
        research_brief_messages = [
            UserMessage(
                generate_research_brief_prompt.format(date=get_today_str(), messages=user_messages)
            ).to_dict()
        ]

        research_brief_task = ChatTask.from_messages(messages=research_brief_messages)
        research_brief_task.worker_tag = Supervisor.WorkerTag.GENERATION

        yield [research_brief_task]

        research_brief = research_brief_task.output_str

        supervisor_prompt_messages = [
            SystemMessage(
                supervisor_system_prompt.format(
                    date=get_today_str(),
                    max_researcher_iterations=self.max_research_iter,
                    max_concurrent_research_units=self.max_concurrent_research_units,
                )
            ).to_dict(),
            UserMessage(research_brief).to_dict(),
        ]

        # TODO: Clarify the research brief with the user.
        # The messages in which the supervisor clarifies the research brief with the user.

        # The messages in which the supervisor uses tools to conduct research.
        supervisor_tools_messages = []

        chat_with_tools_task = ChatTask.from_messages(
            messages=supervisor_prompt_messages + supervisor_tools_messages,
            tools=self.researcher_tools,
        )
        chat_with_tools_task.worker_tag = Supervisor.WorkerTag.GENERATION

        for _ in range(self.max_research_iter):
            yield [chat_with_tools_task]
            if chat_with_tools_task.finish_reason != "tool_calls":
                break

            if chat_with_tools_task.output_str is not None:
                supervisor_tools_messages.append(
                    AssistantMessage(chat_with_tools_task.output_str).to_dict()
                )

            research_tasks = []

            for tool_call in chat_with_tools_task.tool_calls:
                tool_name = tool_call.function.name
                arguments = json.loads(tool_call.function.arguments)

                supervisor_tools_messages.append(
                    AssistantMessage(
                        f"I have called the tool {tool_name} with arguments: {tool_call.function.arguments}"
                    ).to_dict()
                )

                if tool_name == "think_tool":
                    supervisor_tools_messages.append(
                        UserMessage(f"Reflection recorded: {arguments['think']}").to_dict()
                    )
                elif tool_name == "conduct_research":
                    research_tasks.append(ResearchTask.from_topic(arguments["research_topic"]))

                elif tool_name == "complete_research":
                    break

            # In a single research iteration, the supervisor may invoke multiple tools.
            # For example, it might generate several research topics and assign them
            # concurrently to multiple researchers. We gather these in research_tasks
            # to take advantage of the researcher_controller's capability to process
            # them concurrently.
            if len(research_tasks) != 0:
                yield from self.researcher_controller.process(research_tasks)
                for task in research_tasks:
                    supervisor_tools_messages.append(UserMessage(task.research_result).to_dict())

            chat_with_tools_task = ChatTask.from_messages(
                messages=supervisor_prompt_messages + supervisor_tools_messages,
                tools=self.researcher_tools,
            )
            chat_with_tools_task.worker_tag = Supervisor.WorkerTag.GENERATION

        final_report_generation_task = ChatTask.from_messages(
            messages=[
                SystemMessage(
                    final_report_generation_prompt.format(
                        research_brief=research_brief,
                        messages=supervisor_prompt_messages + supervisor_tools_messages,
                        findings=supervisor_tools_messages,
                        date=get_today_str(),
                    )
                ).to_dict()
            ]
        )
        final_report_generation_task.worker_tag = Supervisor.WorkerTag.GENERATION

        yield [final_report_generation_task]

        final_report = final_report_generation_task.output_str

        tasks[0].output_str = final_report
        return
50  tensorrt_llm/scaffolding/contrib/DeepResearch/utils.py  Normal file
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict


def get_today_str() -> str:
    """Get current date formatted for display in prompts and outputs.

    Returns:
        Human-readable date string in format like 'Mon Jan 15, 2024'
    """
    now = datetime.now()
    return f"{now:%a} {now:%b} {now.day}, {now:%Y}"


@dataclass
class RoleMessage:
    role: str
    content: str

    def __str__(self) -> str:
        return f"{self.role}: {self.content}"

    def __repr__(self) -> str:
        return f"{self.role}: {self.content}\n"

    def to_dict(self) -> Dict[str, Any]:
        return {"role": self.role, "content": self.content}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        return cls(role=data["role"], content=data["content"])


@dataclass
class UserMessage(RoleMessage):
    def __init__(self, content: str):
        super().__init__(role="user", content=content)


@dataclass
class AssistantMessage(RoleMessage):
    def __init__(self, content: str):
        super().__init__(role="assistant", content=content)


@dataclass
class SystemMessage(RoleMessage):
    def __init__(self, content: str):
        super().__init__(role="system", content=content)
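A small usage sketch for the helpers above: the Supervisor and Researcher build plain role/content dicts from these classes and hand them to `ChatTask.from_messages` (see the hunk below). The prompt strings here are illustrative only.

```
from tensorrt_llm.scaffolding.contrib.DeepResearch.utils import (
    SystemMessage, UserMessage, get_today_str)

messages = [
    SystemMessage(f"You are a research assistant. Today's date is {get_today_str()}.").to_dict(),
    UserMessage("How large is Japan's elderly consumer market?").to_dict(),
]
# Each entry is a {"role": ..., "content": ...} dict, the shape that
# ChatTask.from_messages(messages) expects.
```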
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from tensorrt_llm.scaffolding import GenerationTask

@@ -17,3 +18,12 @@ class ChatTask(GenerationTask):
        task.messages = messages
        task.tools = tools
        return task

    @staticmethod
    def from_messages(
            messages: List[Dict[str, Any]],
            tools: Optional[List[Dict[str, Any]]] = None) -> "ChatTask":
        task = ChatTask()
        task.messages = messages
        task.tools = tools
        return task