mirror of
https://github.com/langgenius/dify.git
synced 2026-01-23 04:02:21 +08:00
Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com> Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: zhanluxianshen <zhanluxianshen@163.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: GuanMu <ballmanjq@gmail.com> Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: Qiang Lee <18018968632@163.com> Co-authored-by: 李强04 <liqiang04@gaotu.cn> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Matri Qi <matrixdom@126.com> Co-authored-by: huayaoyue6 <huayaoyue@163.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: znn <jubinkumarsoni@gmail.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Muke Wang <shaodwaaron@gmail.com> Co-authored-by: wangmuke <wangmuke@kingsware.cn> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Eric Guo <eric.guocz@gmail.com> Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: jiangbo721 <jiangbo721@163.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: hjlarry <25834719+hjlarry@users.noreply.github.com> Co-authored-by: lxsummer 
<35754229+lxjustdoit@users.noreply.github.com> Co-authored-by: 湛露先生 <zhanluxianshen@163.com> Co-authored-by: Guangdong Liu <liugddx@gmail.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yessenia-d <yessenia.contact@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: 17hz <0x149527@gmail.com> Co-authored-by: Amy <1530140574@qq.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Nite Knite <nkCoding@gmail.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Co-authored-by: Petrus Han <petrus.hanks@gmail.com> Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com> Co-authored-by: Kalo Chin <frog.beepers.0n@icloud.com> Co-authored-by: Ujjwal Maurya <ujjwalsbx@gmail.com> Co-authored-by: Maries <xh001x@hotmail.com>
162 lines
6.0 KiB
Python
162 lines
6.0 KiB
Python
import logging
|
|
from collections.abc import Callable
|
|
from contextlib import AbstractContextManager, ExitStack
|
|
from types import TracebackType
|
|
from typing import Any, Optional, cast
|
|
from urllib.parse import urlparse
|
|
|
|
from core.mcp.client.sse_client import sse_client
|
|
from core.mcp.client.streamable_client import streamablehttp_client
|
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
|
from core.mcp.session.client_session import ClientSession
|
|
from core.mcp.types import Tool
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MCPClient:
    """Synchronous MCP client meant to be used as a context manager.

    Entering the context opens a transport (streamable HTTP or SSE), wraps it
    in a ``ClientSession`` and initializes that session. Leaving the context
    closes everything that was opened through the internal ``ExitStack``.

    Example::

        with MCPClient(url, provider_id, tenant_id) as client:
            tools = client.list_tools()
    """

    def __init__(
        self,
        server_url: str,
        provider_id: str,
        tenant_id: str,
        authed: bool = True,
        authorization_code: Optional[str] = None,
        for_list: bool = False,
        headers: Optional[dict[str, str]] = None,
        timeout: Optional[float] = None,
        sse_read_timeout: Optional[float] = None,
    ):
        """Store connection settings; no network connection is opened here.

        Args:
            server_url: URL of the MCP server. Its last path segment selects
                the transport in ``_initialize``.
            provider_id: Passed to ``OAuthClientProvider`` when ``authed``.
            tenant_id: Passed to ``OAuthClientProvider`` when ``authed``.
            authed: When true, an ``OAuthClientProvider`` is created and its
                ``tokens()`` result is used to build the Authorization header.
            authorization_code: Forwarded to ``auth`` when the server rejects
                the connection with ``MCPAuthError``.
            for_list: Forwarded to ``OAuthClientProvider``.
            headers: Extra HTTP headers sent to the server.
            timeout: Forwarded to the transport factory.
            sse_read_timeout: Forwarded to the transport factory.
        """
        # Initialize info
        self.provider_id = provider_id
        self.tenant_id = tenant_id
        self.client_type = "streamable"
        self.server_url = server_url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout

        # Authentication info. Always define provider/token so that reading
        # them on an unauthenticated client yields None instead of raising
        # AttributeError.
        self.authed = authed
        self.authorization_code = authorization_code
        self.provider: Optional[Any] = None
        self.token: Optional[Any] = None
        if authed:
            # Imported lazily, as in the rest of this class's auth handling.
            from core.mcp.auth.auth_provider import OAuthClientProvider

            self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
            self.token = self.provider.tokens()

        # Initialize session and client objects
        self._session: Optional[ClientSession] = None
        self._streams_context: Optional[AbstractContextManager[Any]] = None
        self._session_context: Optional[ClientSession] = None
        self._exit_stack = ExitStack()

        # Whether the client has been initialized
        self._initialized = False

    def __enter__(self):
        """Connect to the server and return this client.

        If connecting fails, everything that was partially opened is released
        before the original exception propagates.
        """
        try:
            self._initialize()
        except BaseException:
            # ``__exit__`` is not invoked when ``__enter__`` raises, so any
            # context already pushed on the exit stack would otherwise leak.
            try:
                self.cleanup()
            except ValueError:
                # cleanup() has already logged its own failure; the connection
                # error is the one the caller needs to see.
                pass
            raise
        self._initialized = True
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ):
        """Release all resources; exceptions from the ``with`` body are not suppressed."""
        self.cleanup()

    def _initialize(
        self,
    ):
        """Open the connection, choosing the transport from the URL path.

        When the last path segment of ``server_url`` is ``mcp`` or ``sse`` the
        matching transport is used directly. Otherwise SSE is tried first and,
        if that raises ``MCPConnectionError``, streamable HTTP is used as the
        fallback.
        """
        connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
            "mcp": streamablehttp_client,
            "sse": sse_client,
        }

        parsed_url = urlparse(self.server_url)
        path = parsed_url.path or ""
        method_name = path.rstrip("/").split("/")[-1] if path else ""

        if method_name in connection_methods:
            self.connect_server(connection_methods[method_name], method_name)
            return

        # The message names the transport that is actually attempted first.
        logger.debug("Not supported method %s found in URL path, trying 'sse' method first.", method_name)
        try:
            self.connect_server(sse_client, "sse")
        except MCPConnectionError:
            logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
            self.connect_server(streamablehttp_client, "mcp")

    def _request_headers(self) -> dict[str, str]:
        """Return the HTTP headers to send when opening the transport.

        Without a token this is ``self.headers`` unchanged. With a token, the
        caller-supplied headers are kept and the OAuth ``Authorization`` header
        is added on top (it wins over a caller-supplied ``Authorization`` key).
        """
        if not (self.authed and self.token):
            return self.headers
        merged = dict(self.headers)
        merged["Authorization"] = f"{self.token.token_type.capitalize()} {self.token.access_token}"
        return merged

    def connect_server(
        self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
    ):
        """Open a transport with ``client_factory`` and initialize a session on it.

        Args:
            client_factory: Callable returning the transport context manager;
                called with ``url``, ``headers``, ``timeout`` and
                ``sse_read_timeout`` keyword arguments.
            method_name: ``"mcp"`` when the transport yields three values
                (only the first two are used as streams); anything else is
                treated as the SSE transport, whose value is used as-is.
            first_try: Internal flag; false on the single retry that follows
                a successful re-authentication.

        Raises:
            MCPConnectionError: If the factory returns a falsy context.
            MCPAuthError: If the server rejects the connection and the client
                is not ``authed``, or if it is rejected again after
                re-authenticating.
            ValueError: If re-authentication itself fails.
        """
        from core.mcp.auth.auth_flow import auth

        try:
            self._streams_context = client_factory(
                url=self.server_url,
                headers=self._request_headers(),
                timeout=self.timeout,
                sse_read_timeout=self.sse_read_timeout,
            )
            if not self._streams_context:
                raise MCPConnectionError("Failed to create connection context")

            # Use exit_stack to manage context managers properly
            if method_name == "mcp":
                read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
                streams = (read_stream, write_stream)
            else:  # sse_client
                streams = self._exit_stack.enter_context(self._streams_context)

            self._session_context = ClientSession(*streams)
            self._session = self._exit_stack.enter_context(self._session_context)
            session = cast(ClientSession, self._session)
            session.initialize()
            return

        except MCPAuthError:
            # Only an authenticated client can recover, and only once: if the
            # retry is rejected too, surface the error instead of returning as
            # though the connection had succeeded.
            if not self.authed or not first_try:
                raise
            try:
                auth(self.provider, self.server_url, self.authorization_code)
            except Exception as e:
                raise ValueError(f"Failed to authenticate: {e}") from e
            self.token = self.provider.tokens()
            return self.connect_server(client_factory, method_name, first_try=False)

    def list_tools(self) -> "list[Tool]":
        """Return the tools reported by the connected server.

        Raises:
            ValueError: If the client has not been entered or has no session.
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        response = self._session.list_tools()
        return response.tools

    def invoke_tool(self, tool_name: str, tool_args: dict):
        """Call ``tool_name`` with ``tool_args`` on the server and return the session's result.

        Raises:
            ValueError: If the client has not been entered or has no session.
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        return self._session.call_tool(tool_name, tool_args)

    def cleanup(self):
        """Close every managed context and reset the client to its unconnected state.

        Raises:
            ValueError: If closing the exit stack raises; the original error
                is attached as the cause. State is reset either way.
        """
        try:
            # ExitStack will handle proper cleanup of all managed context managers
            self._exit_stack.close()
        except Exception as e:
            logger.exception("Error during cleanup")
            raise ValueError(f"Error during cleanup: {e}") from e
        finally:
            self._session = None
            self._session_context = None
            self._streams_context = None
            self._initialized = False
|