mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
chore: use from __future__ import annotations (#30254)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Dev <dev@Devs-MacBook-Pro-4.local> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
0294555893
commit
4f0fb6df2b
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
|
|||||||
repeat_new_password: str
|
repeat_new_password: str
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_passwords_match(self) -> "AccountPasswordPayload":
|
def check_passwords_match(self) -> AccountPasswordPayload:
|
||||||
if self.new_password != self.repeat_new_password:
|
if self.new_password != self.repeat_new_password:
|
||||||
raise RepeatPasswordNotMatchError()
|
raise RepeatPasswordNotMatchError()
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
|
|||||||
"""
|
"""
|
||||||
return DatasourceProviderType.LOCAL_FILE
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
entity=self.entity.model_copy(),
|
entity=self.entity.model_copy(),
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
|
|||||||
ONLINE_DRIVE = "online_drive"
|
ONLINE_DRIVE = "online_drive"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
def value_of(cls, value: str) -> DatasourceProviderType:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
|
|||||||
typ: DatasourceParameterType,
|
typ: DatasourceParameterType,
|
||||||
required: bool,
|
required: bool,
|
||||||
options: list[str] | None = None,
|
options: list[str] | None = None,
|
||||||
) -> "DatasourceParameter":
|
) -> DatasourceParameter:
|
||||||
"""
|
"""
|
||||||
get a simple datasource parameter
|
get a simple datasource parameter
|
||||||
|
|
||||||
@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
|
|||||||
tool_config: dict | None = None
|
tool_config: dict | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "DatasourceInvokeMeta":
|
def empty(cls) -> DatasourceInvokeMeta:
|
||||||
"""
|
"""
|
||||||
Get an empty instance of DatasourceInvokeMeta
|
Get an empty instance of DatasourceInvokeMeta
|
||||||
"""
|
"""
|
||||||
return cls(time_cost=0.0, error=None, tool_config={})
|
return cls(time_cost=0.0, error=None, tool_config={})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
def error_instance(cls, error: str) -> DatasourceInvokeMeta:
|
||||||
"""
|
"""
|
||||||
Get an instance of DatasourceInvokeMeta with error
|
Get an instance of DatasourceInvokeMeta with error
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
|
|||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
|
||||||
"""Create entity from database model with decryption"""
|
"""Create entity from database model with decryption"""
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
|
|||||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
def value_of(cls, value: str) -> ProviderConfig.Type:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|||||||
@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
request_meta: RequestParams.Meta | None,
|
request_meta: RequestParams.Meta | None,
|
||||||
request: ReceiveRequestT,
|
request: ReceiveRequestT,
|
||||||
session: """BaseSession[
|
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
|
||||||
SendRequestT,
|
|
||||||
SendNotificationT,
|
|
||||||
SendResultT,
|
|
||||||
ReceiveRequestT,
|
|
||||||
ReceiveNotificationT
|
|
||||||
]""",
|
|
||||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
|
|||||||
TOOL = auto()
|
TOOL = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
def value_of(cls, value: str) -> PromptMessageRole:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -20,7 +22,7 @@ class ModelType(StrEnum):
|
|||||||
TTS = auto()
|
TTS = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
def value_of(cls, origin_model_type: str) -> ModelType:
|
||||||
"""
|
"""
|
||||||
Get model type from origin model type.
|
Get model type from origin model type.
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
|
|||||||
JSON_SCHEMA = auto()
|
JSON_SCHEMA = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
def value_of(cls, value: Any) -> DefaultParameterName:
|
||||||
"""
|
"""
|
||||||
Get parameter name from value.
|
Get parameter name from value.
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@ -38,7 +40,7 @@ class ModelProviderFactory:
|
|||||||
plugin_providers = self.get_plugin_model_providers()
|
plugin_providers = self.get_plugin_model_providers()
|
||||||
return [provider.declaration for provider in plugin_providers]
|
return [provider.declaration for provider in plugin_providers]
|
||||||
|
|
||||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all plugin model providers
|
Get all plugin model providers
|
||||||
:return: list of plugin model providers
|
:return: list of plugin model providers
|
||||||
@ -76,7 +78,7 @@ class ModelProviderFactory:
|
|||||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||||
return plugin_model_provider_entity.declaration
|
return plugin_model_provider_entity.declaration
|
||||||
|
|
||||||
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
|
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||||
"""
|
"""
|
||||||
Get plugin model provider
|
Get plugin model provider
|
||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
|
|||||||
return [item.value for item in cls]
|
return [item.value for item in cls]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def of(cls, credential_type: str) -> "CredentialType":
|
def of(cls, credential_type: str) -> CredentialType:
|
||||||
type_name = credential_type.lower()
|
type_name = credential_type.lower()
|
||||||
if type_name in {"api-key", "api_key"}:
|
if type_name in {"api-key", "api_key"}:
|
||||||
return cls.API_KEY
|
return cls.API_KEY
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -6,7 +8,7 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import clickzetta # type: ignore
|
import clickzetta # type: ignore
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
|
|||||||
Manages connection reuse across ClickzettaVector instances.
|
Manages connection reuse across ClickzettaVector instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance: Optional["ClickzettaConnectionPool"] = None
|
_instance: ClickzettaConnectionPool | None = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
|
|||||||
self._start_cleanup_thread()
|
self._start_cleanup_thread()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> "ClickzettaConnectionPool":
|
def get_instance(cls) -> ClickzettaConnectionPool:
|
||||||
"""Get singleton instance of connection pool."""
|
"""Get singleton instance of connection pool."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
|
|||||||
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
|
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_connection(self, config: ClickzettaConfig) -> "Connection":
|
def _create_connection(self, config: ClickzettaConfig) -> Connection:
|
||||||
"""Create a new ClickZetta connection."""
|
"""Create a new ClickZetta connection."""
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
retry_delay = 1.0
|
retry_delay = 1.0
|
||||||
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
|
|||||||
|
|
||||||
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
|
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
|
||||||
|
|
||||||
def _configure_connection(self, connection: "Connection"):
|
def _configure_connection(self, connection: Connection):
|
||||||
"""Configure connection session settings."""
|
"""Configure connection session settings."""
|
||||||
try:
|
try:
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to configure connection, continuing with defaults")
|
logger.exception("Failed to configure connection, continuing with defaults")
|
||||||
|
|
||||||
def _is_connection_valid(self, connection: "Connection") -> bool:
|
def _is_connection_valid(self, connection: Connection) -> bool:
|
||||||
"""Check if connection is still valid."""
|
"""Check if connection is still valid."""
|
||||||
try:
|
try:
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_connection(self, config: ClickzettaConfig) -> "Connection":
|
def get_connection(self, config: ClickzettaConfig) -> Connection:
|
||||||
"""Get a connection from the pool or create a new one."""
|
"""Get a connection from the pool or create a new one."""
|
||||||
config_key = self._get_config_key(config)
|
config_key = self._get_config_key(config)
|
||||||
|
|
||||||
@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
|
|||||||
# No valid connection found, create new one
|
# No valid connection found, create new one
|
||||||
return self._create_connection(config)
|
return self._create_connection(config)
|
||||||
|
|
||||||
def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
|
def return_connection(self, config: ClickzettaConfig, connection: Connection):
|
||||||
"""Return a connection to the pool."""
|
"""Return a connection to the pool."""
|
||||||
config_key = self._get_config_key(config)
|
config_key = self._get_config_key(config)
|
||||||
|
|
||||||
@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
|
|||||||
self._connection_pool = ClickzettaConnectionPool.get_instance()
|
self._connection_pool = ClickzettaConnectionPool.get_instance()
|
||||||
self._init_write_queue()
|
self._init_write_queue()
|
||||||
|
|
||||||
def _get_connection(self) -> "Connection":
|
def _get_connection(self) -> Connection:
|
||||||
"""Get a connection from the pool."""
|
"""Get a connection from the pool."""
|
||||||
return self._connection_pool.get_connection(self._config)
|
return self._connection_pool.get_connection(self._config)
|
||||||
|
|
||||||
def _return_connection(self, connection: "Connection"):
|
def _return_connection(self, connection: Connection):
|
||||||
"""Return a connection to the pool."""
|
"""Return a connection to the pool."""
|
||||||
self._connection_pool.return_connection(self._config, connection)
|
self._connection_pool.return_connection(self._config, connection)
|
||||||
|
|
||||||
class ConnectionContext:
|
class ConnectionContext:
|
||||||
"""Context manager for borrowing and returning connections."""
|
"""Context manager for borrowing and returning connections."""
|
||||||
|
|
||||||
def __init__(self, vector_instance: "ClickzettaVector"):
|
def __init__(self, vector_instance: ClickzettaVector):
|
||||||
self.vector = vector_instance
|
self.vector = vector_instance
|
||||||
self.connection: Connection | None = None
|
self.connection: Connection | None = None
|
||||||
|
|
||||||
def __enter__(self) -> "Connection":
|
def __enter__(self) -> Connection:
|
||||||
self.connection = self.vector._get_connection()
|
self.connection = self.vector._get_connection()
|
||||||
return self.connection
|
return self.connection
|
||||||
|
|
||||||
@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
|
|||||||
if self.connection:
|
if self.connection:
|
||||||
self.vector._return_connection(self.connection)
|
self.vector._return_connection(self.connection)
|
||||||
|
|
||||||
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
|
def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
|
||||||
"""Get a connection context manager."""
|
"""Get a connection context manager."""
|
||||||
return self.ConnectionContext(self)
|
return self.ConnectionContext(self)
|
||||||
|
|
||||||
@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
|
|||||||
"""Return the vector database type."""
|
"""Return the vector database type."""
|
||||||
return "clickzetta"
|
return "clickzetta"
|
||||||
|
|
||||||
def _ensure_connection(self) -> "Connection":
|
def _ensure_connection(self) -> Connection:
|
||||||
"""Get a connection from the pool."""
|
"""Get a connection from the pool."""
|
||||||
return self._get_connection()
|
return self._get_connection()
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -22,7 +24,7 @@ class DatasetDocumentStore:
|
|||||||
self._document_id = document_id
|
self._document_id = document_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
|
def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
|
||||||
return cls(**config_dict)
|
return cls(**config_dict)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
|
|||||||
return self.model_dump_json()
|
return self.model_dump_json()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
|
def deserialize(cls, serialized_data: str) -> TaskWrapper:
|
||||||
return cls.model_validate_json(serialized_data)
|
return cls.model_validate_json(serialized_data)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Mapping, MutableMapping
|
from collections.abc import Mapping, MutableMapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar, Optional
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
|
||||||
class SchemaRegistry:
|
class SchemaRegistry:
|
||||||
@ -11,7 +13,7 @@ class SchemaRegistry:
|
|||||||
|
|
||||||
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
|
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
|
||||||
|
|
||||||
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
|
_default_instance: ClassVar[SchemaRegistry | None] = None
|
||||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, base_dir: str):
|
def __init__(self, base_dir: str):
|
||||||
@ -20,7 +22,7 @@ class SchemaRegistry:
|
|||||||
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
|
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_registry(cls) -> "SchemaRegistry":
|
def default_registry(cls) -> SchemaRegistry:
|
||||||
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
|
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
|
||||||
if cls._default_instance is None:
|
if cls._default_instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -24,7 +26,7 @@ class Tool(ABC):
|
|||||||
self.entity = entity
|
self.entity = entity
|
||||||
self.runtime = runtime
|
self.runtime = runtime
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
:return: the new tool
|
:return: the new tool
|
||||||
@ -166,7 +168,7 @@ class Tool(ABC):
|
|||||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
def create_file_message(self, file: File) -> ToolInvokeMessage:
|
||||||
return ToolInvokeMessage(
|
return ToolInvokeMessage(
|
||||||
type=ToolInvokeMessage.MessageType.FILE,
|
type=ToolInvokeMessage.MessageType.FILE,
|
||||||
message=ToolInvokeMessage.FileMessage(),
|
message=ToolInvokeMessage.FileMessage(),
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
@ -24,7 +26,7 @@ class BuiltinTool(Tool):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
:return: the new tool
|
:return: the new tool
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
|
|||||||
self.tools = []
|
self.tools = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
|
||||||
credentials_schema = [
|
credentials_schema = [
|
||||||
ProviderConfig(
|
ProviderConfig(
|
||||||
name="auth_type",
|
name="auth_type",
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import contextlib
|
import contextlib
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
|
|||||||
MCP = auto()
|
MCP = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "ToolProviderType":
|
def value_of(cls, value: str) -> ToolProviderType:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
|
|||||||
OPENAI_ACTIONS = auto()
|
OPENAI_ACTIONS = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
def value_of(cls, value: str) -> ApiProviderSchemaType:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
|
|||||||
API_KEY_QUERY = auto()
|
API_KEY_QUERY = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
def value_of(cls, value: str) -> ApiProviderAuthType:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
|
|||||||
typ: ToolParameterType,
|
typ: ToolParameterType,
|
||||||
required: bool,
|
required: bool,
|
||||||
options: list[str] | None = None,
|
options: list[str] | None = None,
|
||||||
) -> "ToolParameter":
|
) -> ToolParameter:
|
||||||
"""
|
"""
|
||||||
get a simple tool parameter
|
get a simple tool parameter
|
||||||
|
|
||||||
@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
|
|||||||
tool_config: dict | None = None
|
tool_config: dict | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "ToolInvokeMeta":
|
def empty(cls) -> ToolInvokeMeta:
|
||||||
"""
|
"""
|
||||||
Get an empty instance of ToolInvokeMeta
|
Get an empty instance of ToolInvokeMeta
|
||||||
"""
|
"""
|
||||||
return cls(time_cost=0.0, error=None, tool_config={})
|
return cls(time_cost=0.0, error=None, tool_config={})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def error_instance(cls, error: str) -> "ToolInvokeMeta":
|
def error_instance(cls, error: str) -> ToolInvokeMeta:
|
||||||
"""
|
"""
|
||||||
Get an instance of ToolInvokeMeta with error
|
Get an instance of ToolInvokeMeta with error
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -118,7 +120,7 @@ class MCPTool(Tool):
|
|||||||
for item in json_list:
|
for item in json_list:
|
||||||
yield self.create_json_message(item)
|
yield self.create_json_message(item)
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
|
||||||
return MCPTool(
|
return MCPTool(
|
||||||
entity=self.entity,
|
entity=self.entity,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -46,7 +48,7 @@ class PluginTool(Tool):
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
|
||||||
return PluginTool(
|
return PluginTool(
|
||||||
entity=self.entity,
|
entity=self.entity,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
self.provider_id = provider_id
|
self.provider_id = provider_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||||
with session_factory.create_session() as session, session.begin():
|
with session_factory.create_session() as session, session.begin():
|
||||||
app = session.get(App, db_provider.app_id)
|
app = session.get(App, db_provider.app_id)
|
||||||
if not app:
|
if not app:
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
@ -181,7 +183,7 @@ class WorkflowTool(Tool):
|
|||||||
return found
|
return found
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
@ -52,7 +54,7 @@ class SegmentType(StrEnum):
|
|||||||
return self in _ARRAY_TYPES
|
return self in _ARRAY_TYPES
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
|
def infer_segment_type(cls, value: Any) -> SegmentType | None:
|
||||||
"""
|
"""
|
||||||
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
|
||||||
|
|
||||||
@ -173,7 +175,7 @@ class SegmentType(StrEnum):
|
|||||||
raise AssertionError("this statement should be unreachable.")
|
raise AssertionError("this statement should be unreachable.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cast_value(value: Any, type_: "SegmentType"):
|
def cast_value(value: Any, type_: SegmentType):
|
||||||
# Cast Python's `bool` type to `int` when the runtime type requires
|
# Cast Python's `bool` type to `int` when the runtime type requires
|
||||||
# an integer or number.
|
# an integer or number.
|
||||||
#
|
#
|
||||||
@ -193,7 +195,7 @@ class SegmentType(StrEnum):
|
|||||||
return [int(i) for i in value]
|
return [int(i) for i in value]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def exposed_type(self) -> "SegmentType":
|
def exposed_type(self) -> SegmentType:
|
||||||
"""Returns the type exposed to the frontend.
|
"""Returns the type exposed to the frontend.
|
||||||
|
|
||||||
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
|
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
|
||||||
@ -202,7 +204,7 @@ class SegmentType(StrEnum):
|
|||||||
return SegmentType.NUMBER
|
return SegmentType.NUMBER
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def element_type(self) -> "SegmentType | None":
|
def element_type(self) -> SegmentType | None:
|
||||||
"""Return the element type of the current segment type, or `None` if the element type is undefined.
|
"""Return the element type of the current segment type, or `None` if the element type is undefined.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -217,7 +219,7 @@ class SegmentType(StrEnum):
|
|||||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_zero_value(t: "SegmentType"):
|
def get_zero_value(t: SegmentType):
|
||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
|
|||||||
implementation details like tenant_id, app_id, etc.
|
implementation details like tenant_id, app_id, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
|
|||||||
graph: Mapping[str, Any],
|
graph: Mapping[str, Any],
|
||||||
inputs: Mapping[str, Any],
|
inputs: Mapping[str, Any],
|
||||||
started_at: datetime,
|
started_at: datetime,
|
||||||
) -> "WorkflowExecution":
|
) -> WorkflowExecution:
|
||||||
return WorkflowExecution(
|
return WorkflowExecution(
|
||||||
id_=id_,
|
id_=id_,
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
@ -175,7 +177,7 @@ class Graph:
|
|||||||
def _create_node_instances(
|
def _create_node_instances(
|
||||||
cls,
|
cls,
|
||||||
node_configs_map: dict[str, dict[str, object]],
|
node_configs_map: dict[str, dict[str, object]],
|
||||||
node_factory: "NodeFactory",
|
node_factory: NodeFactory,
|
||||||
) -> dict[str, Node]:
|
) -> dict[str, Node]:
|
||||||
"""
|
"""
|
||||||
Create node instances from configurations using the node factory.
|
Create node instances from configurations using the node factory.
|
||||||
@ -197,7 +199,7 @@ class Graph:
|
|||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(cls) -> "GraphBuilder":
|
def new(cls) -> GraphBuilder:
|
||||||
"""Create a fluent builder for assembling a graph programmatically."""
|
"""Create a fluent builder for assembling a graph programmatically."""
|
||||||
|
|
||||||
return GraphBuilder(graph_cls=cls)
|
return GraphBuilder(graph_cls=cls)
|
||||||
@ -284,9 +286,9 @@ class Graph:
|
|||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
graph_config: Mapping[str, object],
|
graph_config: Mapping[str, object],
|
||||||
node_factory: "NodeFactory",
|
node_factory: NodeFactory,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
) -> "Graph":
|
) -> Graph:
|
||||||
"""
|
"""
|
||||||
Initialize graph
|
Initialize graph
|
||||||
|
|
||||||
@ -383,7 +385,7 @@ class GraphBuilder:
|
|||||||
self._edges: list[Edge] = []
|
self._edges: list[Edge] = []
|
||||||
self._edge_counter = 0
|
self._edge_counter = 0
|
||||||
|
|
||||||
def add_root(self, node: Node) -> "GraphBuilder":
|
def add_root(self, node: Node) -> GraphBuilder:
|
||||||
"""Register the root node. Must be called exactly once."""
|
"""Register the root node. Must be called exactly once."""
|
||||||
|
|
||||||
if self._nodes:
|
if self._nodes:
|
||||||
@ -398,7 +400,7 @@ class GraphBuilder:
|
|||||||
*,
|
*,
|
||||||
from_node_id: str | None = None,
|
from_node_id: str | None = None,
|
||||||
source_handle: str = "source",
|
source_handle: str = "source",
|
||||||
) -> "GraphBuilder":
|
) -> GraphBuilder:
|
||||||
"""Append a node and connect it from the specified predecessor."""
|
"""Append a node and connect it from the specified predecessor."""
|
||||||
|
|
||||||
if not self._nodes:
|
if not self._nodes:
|
||||||
@ -419,7 +421,7 @@ class GraphBuilder:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
|
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
|
||||||
"""Connect two existing nodes without adding a new node."""
|
"""Connect two existing nodes without adding a new node."""
|
||||||
|
|
||||||
if tail not in self._nodes_by_id:
|
if tail not in self._nodes_by_id:
|
||||||
|
|||||||
@ -5,6 +5,8 @@ This engine uses a modular architecture with separated packages following
|
|||||||
Domain-Driven Design principles for improved maintainability and testability.
|
Domain-Driven Design principles for improved maintainability and testability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
@ -232,7 +234,7 @@ class GraphEngine:
|
|||||||
) -> None:
|
) -> None:
|
||||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||||
|
|
||||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
|
||||||
"""Add a layer for extending functionality."""
|
"""Add a layer for extending functionality."""
|
||||||
self._layers.append(layer)
|
self._layers.append(layer)
|
||||||
self._bind_layer_context(layer)
|
self._bind_layer_context(layer)
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
Factory for creating ReadyQueue instances from serialized state.
|
Factory for creating ReadyQueue instances from serialized state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from .in_memory import InMemoryReadyQueue
|
from .in_memory import InMemoryReadyQueue
|
||||||
@ -11,7 +13,7 @@ if TYPE_CHECKING:
|
|||||||
from .protocol import ReadyQueue
|
from .protocol import ReadyQueue
|
||||||
|
|
||||||
|
|
||||||
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
|
||||||
"""
|
"""
|
||||||
Create a ReadyQueue instance from a serialized state.
|
Create a ReadyQueue instance from a serialized state.
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
|
|||||||
by ResponseStreamCoordinator to manage streaming sessions.
|
by ResponseStreamCoordinator to manage streaming sessions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
@ -27,7 +29,7 @@ class ResponseSession:
|
|||||||
index: int = 0 # Current position in the template segments
|
index: int = 0 # Current position in the template segments
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_node(cls, node: Node) -> "ResponseSession":
|
def from_node(cls, node: Node) -> ResponseSession:
|
||||||
"""
|
"""
|
||||||
Create a ResponseSession from an AnswerNode or EndNode.
|
Create a ResponseSession from an AnswerNode or EndNode.
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
node_data: AgentNodeData,
|
node_data: AgentNodeData,
|
||||||
for_log: bool = False,
|
for_log: bool = False,
|
||||||
strategy: "PluginAgentStrategy",
|
strategy: PluginAgentStrategy,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||||
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||||||
def _generate_credentials(
|
def _generate_credentials(
|
||||||
self,
|
self,
|
||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
) -> "InvokeCredentials":
|
) -> InvokeCredentials:
|
||||||
"""
|
"""
|
||||||
Generate credentials based on the given agent parameters.
|
Generate credentials based on the given agent parameters.
|
||||||
"""
|
"""
|
||||||
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||||||
model_schema.features.remove(feature)
|
model_schema.features.remove(feature)
|
||||||
return model_schema
|
return model_schema
|
||||||
|
|
||||||
def _filter_mcp_type_tool(
|
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Filter MCP type tool
|
Filter MCP type tool
|
||||||
:param strategy: plugin agent strategy
|
:param strategy: plugin agent strategy
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from builtins import type as type_
|
from builtins import type as type_
|
||||||
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
|
|||||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_value_type(self) -> "DefaultValue":
|
def validate_value_type(self) -> DefaultValue:
|
||||||
# Type validation configuration
|
# Type validation configuration
|
||||||
type_validators = {
|
type_validators = {
|
||||||
DefaultValueType.STRING: {
|
DefaultValueType.STRING: {
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
@ -59,7 +61,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Node(Generic[NodeDataT]):
|
class Node(Generic[NodeDataT]):
|
||||||
node_type: ClassVar["NodeType"]
|
node_type: ClassVar[NodeType]
|
||||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||||
|
|
||||||
@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Global registry populated via __init_subclass__
|
# Global registry populated via __init_subclass__
|
||||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
config: Mapping[str, Any],
|
config: Mapping[str, Any],
|
||||||
graph_init_params: "GraphInitParams",
|
graph_init_params: GraphInitParams,
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
graph_runtime_state: GraphRuntimeState,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._graph_init_params = graph_init_params
|
self._graph_init_params = graph_init_params
|
||||||
self.id = id
|
self.id = id
|
||||||
@ -241,7 +243,7 @@ class Node(Generic[NodeDataT]):
|
|||||||
return
|
return
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def graph_init_params(self) -> "GraphInitParams":
|
def graph_init_params(self) -> GraphInitParams:
|
||||||
return self._graph_init_params
|
return self._graph_init_params
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -457,7 +459,7 @@ class Node(Generic[NodeDataT]):
|
|||||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||||
|
|
||||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||||
|
|||||||
@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
|
|||||||
similar to SegmentGroup but focused on template representation without values.
|
similar to SegmentGroup but focused on template representation without values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -58,7 +60,7 @@ class Template:
|
|||||||
segments: list[TemplateSegmentUnion]
|
segments: list[TemplateSegmentUnion]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_answer_template(cls, template_str: str) -> "Template":
|
def from_answer_template(cls, template_str: str) -> Template:
|
||||||
"""Create a Template from an Answer node template string.
|
"""Create a Template from an Answer node template string.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -107,7 +109,7 @@ class Template:
|
|||||||
return cls(segments=segments)
|
return cls(segments=segments)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
|
||||||
"""Create a Template from an End node outputs configuration.
|
"""Create a Template from an End node outputs configuration.
|
||||||
|
|
||||||
End nodes are treated as templates of concatenated variables with newlines.
|
End nodes are treated as templates of concatenated variables with newlines.
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
|
|
||||||
# Instance attributes specific to LLMNode.
|
# Instance attributes specific to LLMNode.
|
||||||
# Output variable for file
|
# Output variable for file
|
||||||
_file_outputs: list["File"]
|
_file_outputs: list[File]
|
||||||
|
|
||||||
_llm_file_saver: LLMFileSaver
|
_llm_file_saver: LLMFileSaver
|
||||||
|
|
||||||
@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
config: Mapping[str, Any],
|
config: Mapping[str, Any],
|
||||||
graph_init_params: "GraphInitParams",
|
graph_init_params: GraphInitParams,
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
graph_runtime_state: GraphRuntimeState,
|
||||||
*,
|
*,
|
||||||
llm_file_saver: LLMFileSaver | None = None,
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
):
|
):
|
||||||
@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
structured_output_enabled: bool,
|
structured_output_enabled: bool,
|
||||||
structured_output: Mapping[str, Any] | None = None,
|
structured_output: Mapping[str, Any] | None = None,
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
file_outputs: list["File"],
|
file_outputs: list[File],
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_type: NodeType,
|
node_type: NodeType,
|
||||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||||
@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
*,
|
*,
|
||||||
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
|
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
file_outputs: list["File"],
|
file_outputs: list[File],
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_type: NodeType,
|
node_type: NodeType,
|
||||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||||
@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _image_file_to_markdown(file: "File", /):
|
def _image_file_to_markdown(file: File, /):
|
||||||
text_chunk = f"})"
|
text_chunk = f"})"
|
||||||
return text_chunk
|
return text_chunk
|
||||||
|
|
||||||
@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
def fetch_prompt_messages(
|
def fetch_prompt_messages(
|
||||||
*,
|
*,
|
||||||
sys_query: str | None = None,
|
sys_query: str | None = None,
|
||||||
sys_files: Sequence["File"],
|
sys_files: Sequence[File],
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
context_files: list["File"] | None = None,
|
context_files: list[File] | None = None,
|
||||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
|
||||||
@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
*,
|
*,
|
||||||
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
invoke_result: LLMResult | LLMResultWithStructuredOutput,
|
||||||
saver: LLMFileSaver,
|
saver: LLMFileSaver,
|
||||||
file_outputs: list["File"],
|
file_outputs: list[File],
|
||||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||||
request_latency: float | None = None,
|
request_latency: float | None = None,
|
||||||
) -> ModelInvokeCompletedEvent:
|
) -> ModelInvokeCompletedEvent:
|
||||||
@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
*,
|
*,
|
||||||
content: ImagePromptMessageContent,
|
content: ImagePromptMessageContent,
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
) -> "File":
|
) -> File:
|
||||||
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||||
|
|
||||||
There are two kinds of multimodal outputs:
|
There are two kinds of multimodal outputs:
|
||||||
@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
*,
|
*,
|
||||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
file_outputs: list["File"],
|
file_outputs: list[File],
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
|
|||||||
node_type: NodeType,
|
node_type: NodeType,
|
||||||
node_execution_id: str,
|
node_execution_id: str,
|
||||||
enclosing_node_id: str | None = None,
|
enclosing_node_id: str | None = None,
|
||||||
) -> "DraftVariableSaver":
|
) -> DraftVariableSaver:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
|
|||||||
self.add(selector, value)
|
self.add(selector, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "VariablePool":
|
def empty(cls) -> VariablePool:
|
||||||
"""Create an empty variable pool."""
|
"""Create an empty variable pool."""
|
||||||
return cls(system_variables=SystemVariable.empty())
|
return cls(system_variables=SystemVariable.empty())
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "SystemVariable":
|
def empty(cls) -> SystemVariable:
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
||||||
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
|
|||||||
d[SystemVariableKey.TIMESTAMP] = self.timestamp
|
d[SystemVariableKey.TIMESTAMP] = self.timestamp
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def as_view(self) -> "SystemVariableReadOnlyView":
|
def as_view(self) -> SystemVariableReadOnlyView:
|
||||||
return SystemVariableReadOnlyView(self)
|
return SystemVariableReadOnlyView(self)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@ -33,7 +35,7 @@ class AliyunLogStore:
|
|||||||
Ensures only one instance exists to prevent multiple PG connection pools.
|
Ensures only one instance exists to prevent multiple PG connection pools.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance: "AliyunLogStore | None" = None
|
_instance: AliyunLogStore | None = None
|
||||||
_initialized: bool = False
|
_initialized: bool = False
|
||||||
|
|
||||||
# Track delayed PG connection for newly created projects
|
# Track delayed PG connection for newly created projects
|
||||||
@ -66,7 +68,7 @@ class AliyunLogStore:
|
|||||||
"\t",
|
"\t",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __new__(cls) -> "AliyunLogStore":
|
def __new__(cls) -> AliyunLogStore:
|
||||||
"""Implement singleton pattern."""
|
"""Implement singleton pattern."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
|
|||||||
@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
|
|||||||
Supports complete lifecycle management for knowledge base files.
|
Supports complete lifecycle management for knowledge base files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
@ -48,7 +50,7 @@ class FileMetadata:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> "FileMetadata":
|
def from_dict(cls, data: dict) -> FileMetadata:
|
||||||
"""Create instance from dictionary"""
|
"""Create instance from dictionary"""
|
||||||
data = data.copy()
|
data = data.copy()
|
||||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
Broadcast channel for Pub/Sub messaging.
|
Broadcast channel for Pub/Sub messaging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import types
|
import types
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
@ -129,6 +131,6 @@ class BroadcastChannel(Protocol):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def topic(self, topic: str) -> "Topic":
|
def topic(self, topic: str) -> Topic:
|
||||||
"""topic returns a `Topic` instance for the given topic name."""
|
"""topic returns a `Topic` instance for the given topic name."""
|
||||||
...
|
...
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
@ -20,7 +22,7 @@ class BroadcastChannel:
|
|||||||
):
|
):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
|
|
||||||
def topic(self, topic: str) -> "Topic":
|
def topic(self, topic: str) -> Topic:
|
||||||
return Topic(self._client, topic)
|
return Topic(self._client, topic)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
|
|||||||
):
|
):
|
||||||
self._client = redis_client
|
self._client = redis_client
|
||||||
|
|
||||||
def topic(self, topic: str) -> "ShardedTopic":
|
def topic(self, topic: str) -> ShardedTopic:
|
||||||
return ShardedTopic(self._client, topic)
|
return ShardedTopic(self._client, topic)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and
|
|||||||
eliminates the need for repetitive language switching logic.
|
eliminates the need for repetitive language switching logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
|
|||||||
ZH_HANS = "zh-Hans"
|
ZH_HANS = "zh-Hans"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_language_code(cls, language_code: str) -> "EmailLanguage":
|
def from_language_code(cls, language_code: str) -> EmailLanguage:
|
||||||
"""Convert a language code to EmailLanguage with fallback to English."""
|
"""Convert a language code to EmailLanguage with fallback to English."""
|
||||||
if language_code == "zh-Hans":
|
if language_code == "zh-Hans":
|
||||||
return cls.ZH_HANS
|
return cls.ZH_HANS
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@ -5,7 +7,7 @@ from collections.abc import Mapping
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@ -54,7 +56,7 @@ class AppMode(StrEnum):
|
|||||||
RAG_PIPELINE = "rag-pipeline"
|
RAG_PIPELINE = "rag-pipeline"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "AppMode":
|
def value_of(cls, value: str) -> AppMode:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -121,19 +123,19 @@ class App(Base):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def site(self) -> Optional["Site"]:
|
def site(self) -> Site | None:
|
||||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||||
return site
|
return site
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_model_config(self) -> Optional["AppModelConfig"]:
|
def app_model_config(self) -> AppModelConfig | None:
|
||||||
if self.app_model_config_id:
|
if self.app_model_config_id:
|
||||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workflow(self) -> Optional["Workflow"]:
|
def workflow(self) -> Workflow | None:
|
||||||
if self.workflow_id:
|
if self.workflow_id:
|
||||||
from .workflow import Workflow
|
from .workflow import Workflow
|
||||||
|
|
||||||
@ -288,7 +290,7 @@ class App(Base):
|
|||||||
return deleted_tools
|
return deleted_tools
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tags(self) -> list["Tag"]:
|
def tags(self) -> list[Tag]:
|
||||||
tags = (
|
tags = (
|
||||||
db.session.query(Tag)
|
db.session.query(Tag)
|
||||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||||
@ -1194,7 +1196,7 @@ class Message(Base):
|
|||||||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agent_thoughts(self) -> list["MessageAgentThought"]:
|
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||||
return (
|
return (
|
||||||
db.session.query(MessageAgentThought)
|
db.session.query(MessageAgentThought)
|
||||||
.where(MessageAgentThought.message_id == self.id)
|
.where(MessageAgentThought.message_id == self.id)
|
||||||
@ -1307,7 +1309,7 @@ class Message(Base):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "Message":
|
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data["id"],
|
||||||
app_id=data["app_id"],
|
app_id=data["app_id"],
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@ -19,7 +21,7 @@ class ProviderType(StrEnum):
|
|||||||
SYSTEM = auto()
|
SYSTEM = auto()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def value_of(value: str) -> "ProviderType":
|
def value_of(value: str) -> ProviderType:
|
||||||
for member in ProviderType:
|
for member in ProviderType:
|
||||||
if member.value == value:
|
if member.value == value:
|
||||||
return member
|
return member
|
||||||
@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum):
|
|||||||
"""hosted trial quota"""
|
"""hosted trial quota"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def value_of(value: str) -> "ProviderQuotaType":
|
def value_of(value: str) -> ProviderQuotaType:
|
||||||
for member in ProviderQuotaType:
|
for member in ProviderQuotaType:
|
||||||
if member.value == value:
|
if member.value == value:
|
||||||
return member
|
return member
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema_type(self) -> "ApiProviderSchemaType":
|
def schema_type(self) -> ApiProviderSchemaType:
|
||||||
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tools(self) -> list["ApiToolBundle"]:
|
def tools(self) -> list[ApiToolBundle]:
|
||||||
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
|
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase):
|
|||||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
|
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||||
return [
|
return [
|
||||||
WorkflowToolParameterConfiguration.model_validate(config)
|
WorkflowToolParameterConfiguration.model_validate(config)
|
||||||
for config in json.loads(self.parameter_configuration)
|
for config in json.loads(self.parameter_configuration)
|
||||||
@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase):
|
|||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def to_entity(self) -> "MCPProviderEntity":
|
def to_entity(self) -> MCPProviderEntity:
|
||||||
"""Convert to domain entity"""
|
"""Convert to domain entity"""
|
||||||
from core.entities.mcp_provider import MCPProviderEntity
|
from core.entities.mcp_provider import MCPProviderEntity
|
||||||
|
|
||||||
@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description_i18n(self) -> "I18nObject":
|
def description_i18n(self) -> I18nObject:
|
||||||
return I18nObject.model_validate(json.loads(self.description))
|
return I18nObject.model_validate(json.loads(self.description))
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
from typing import TYPE_CHECKING, Any, Union, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@ -67,7 +69,7 @@ class WorkflowType(StrEnum):
|
|||||||
RAG_PIPELINE = "rag-pipeline"
|
RAG_PIPELINE = "rag-pipeline"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "WorkflowType":
|
def value_of(cls, value: str) -> WorkflowType:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -80,7 +82,7 @@ class WorkflowType(StrEnum):
|
|||||||
raise ValueError(f"invalid workflow type value {value}")
|
raise ValueError(f"invalid workflow type value {value}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
|
||||||
"""
|
"""
|
||||||
Get workflow type from app mode.
|
Get workflow type from app mode.
|
||||||
|
|
||||||
@ -181,7 +183,7 @@ class Workflow(Base): # bug
|
|||||||
rag_pipeline_variables: list[dict],
|
rag_pipeline_variables: list[dict],
|
||||||
marked_name: str = "",
|
marked_name: str = "",
|
||||||
marked_comment: str = "",
|
marked_comment: str = "",
|
||||||
) -> "Workflow":
|
) -> Workflow:
|
||||||
workflow = Workflow()
|
workflow = Workflow()
|
||||||
workflow.id = str(uuid4())
|
workflow.id = str(uuid4())
|
||||||
workflow.tenant_id = tenant_id
|
workflow.tenant_id = tenant_id
|
||||||
@ -619,7 +621,7 @@ class WorkflowRun(Base):
|
|||||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||||
|
|
||||||
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
|
pause: Mapped[WorkflowPause | None] = orm.relationship(
|
||||||
"WorkflowPause",
|
"WorkflowPause",
|
||||||
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
@ -689,7 +691,7 @@ class WorkflowRun(Base):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
|
||||||
return cls(
|
return cls(
|
||||||
id=data.get("id"),
|
id=data.get("id"),
|
||||||
tenant_id=data.get("tenant_id"),
|
tenant_id=data.get("tenant_id"),
|
||||||
@ -841,7 +843,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
|
|
||||||
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
|
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
|
||||||
"WorkflowNodeExecutionOffload",
|
"WorkflowNodeExecutionOffload",
|
||||||
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
|
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
|
||||||
uselist=True,
|
uselist=True,
|
||||||
@ -851,13 +853,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def preload_offload_data(
|
def preload_offload_data(
|
||||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||||
):
|
):
|
||||||
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
|
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def preload_offload_data_and_files(
|
def preload_offload_data_and_files(
|
||||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||||
):
|
):
|
||||||
return query.options(
|
return query.options(
|
||||||
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
|
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
|
||||||
@ -932,7 +934,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
)
|
)
|
||||||
return extras
|
return extras
|
||||||
|
|
||||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
|
||||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1046,7 +1048,7 @@ class WorkflowNodeExecutionOffload(Base):
|
|||||||
back_populates="offload_data",
|
back_populates="offload_data",
|
||||||
)
|
)
|
||||||
|
|
||||||
file: Mapped[Optional["UploadFile"]] = orm.relationship(
|
file: Mapped[UploadFile | None] = orm.relationship(
|
||||||
foreign_keys=[file_id],
|
foreign_keys=[file_id],
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
@ -1064,7 +1066,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
|||||||
INSTALLED_APP = "installed-app"
|
INSTALLED_APP = "installed-app"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
@ -1181,7 +1183,7 @@ class ConversationVariable(TypeBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
|
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
|
||||||
obj = cls(
|
obj = cls(
|
||||||
id=variable.id,
|
id=variable.id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
@ -1334,7 +1336,7 @@ class WorkflowDraftVariable(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Relationship to WorkflowDraftVariableFile
|
# Relationship to WorkflowDraftVariableFile
|
||||||
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
|
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
|
||||||
foreign_keys=[file_id],
|
foreign_keys=[file_id],
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
@ -1504,7 +1506,7 @@ class WorkflowDraftVariable(Base):
|
|||||||
node_execution_id: str | None,
|
node_execution_id: str | None,
|
||||||
description: str = "",
|
description: str = "",
|
||||||
file_id: str | None = None,
|
file_id: str | None = None,
|
||||||
) -> "WorkflowDraftVariable":
|
) -> WorkflowDraftVariable:
|
||||||
variable = WorkflowDraftVariable()
|
variable = WorkflowDraftVariable()
|
||||||
variable.id = str(uuid4())
|
variable.id = str(uuid4())
|
||||||
variable.created_at = naive_utc_now()
|
variable.created_at = naive_utc_now()
|
||||||
@ -1527,7 +1529,7 @@ class WorkflowDraftVariable(Base):
|
|||||||
name: str,
|
name: str,
|
||||||
value: Segment,
|
value: Segment,
|
||||||
description: str = "",
|
description: str = "",
|
||||||
) -> "WorkflowDraftVariable":
|
) -> WorkflowDraftVariable:
|
||||||
variable = cls._new(
|
variable = cls._new(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||||
@ -1548,7 +1550,7 @@ class WorkflowDraftVariable(Base):
|
|||||||
value: Segment,
|
value: Segment,
|
||||||
node_execution_id: str,
|
node_execution_id: str,
|
||||||
editable: bool = False,
|
editable: bool = False,
|
||||||
) -> "WorkflowDraftVariable":
|
) -> WorkflowDraftVariable:
|
||||||
variable = cls._new(
|
variable = cls._new(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||||
@ -1571,7 +1573,7 @@ class WorkflowDraftVariable(Base):
|
|||||||
visible: bool = True,
|
visible: bool = True,
|
||||||
editable: bool = True,
|
editable: bool = True,
|
||||||
file_id: str | None = None,
|
file_id: str | None = None,
|
||||||
) -> "WorkflowDraftVariable":
|
) -> WorkflowDraftVariable:
|
||||||
variable = cls._new(
|
variable = cls._new(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
@ -1667,7 +1669,7 @@ class WorkflowDraftVariableFile(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Relationship to UploadFile
|
# Relationship to UploadFile
|
||||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
upload_file: Mapped[UploadFile] = orm.relationship(
|
||||||
foreign_keys=[upload_file_id],
|
foreign_keys=[upload_file_id],
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
@ -1734,7 +1736,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
|||||||
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||||
|
|
||||||
# Relationship to WorkflowRun
|
# Relationship to WorkflowRun
|
||||||
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
|
workflow_run: Mapped[WorkflowRun] = orm.relationship(
|
||||||
foreign_keys=[workflow_run_id],
|
foreign_keys=[workflow_run_id],
|
||||||
# require explicit preloading.
|
# require explicit preloading.
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
@ -1790,7 +1792,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
|
||||||
if isinstance(pause_reason, HumanInputRequired):
|
if isinstance(pause_reason, HumanInputRequired):
|
||||||
return cls(
|
return cls(
|
||||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator):
|
|||||||
self._max_size_bytes = max_size_bytes
|
self._max_size_bytes = max_size_bytes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls) -> "VariableTruncator":
|
def default(cls) -> VariableTruncator:
|
||||||
return VariableTruncator(
|
return VariableTruncator(
|
||||||
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
|
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
|
||||||
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
|
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest:
|
|||||||
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
|
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
|
||||||
"""Create from Flask-RESTful parsed arguments."""
|
"""Create from Flask-RESTful parsed arguments."""
|
||||||
provider = args.get("provider")
|
provider = args.get("provider")
|
||||||
url = args.get("url")
|
url = args.get("url")
|
||||||
@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest:
|
|||||||
job_id: str
|
job_id: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
|
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
|
||||||
"""Create from Flask-RESTful parsed arguments."""
|
"""Create from Flask-RESTful parsed arguments."""
|
||||||
provider = args.get("provider")
|
provider = args.get("provider")
|
||||||
if not provider:
|
if not provider:
|
||||||
|
|||||||
@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing
|
|||||||
the behavior of mock nodes during testing.
|
the behavior of mock nodes during testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -95,67 +97,67 @@ class MockConfigBuilder:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._config = MockConfig()
|
self._config = MockConfig()
|
||||||
|
|
||||||
def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
|
def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder:
|
||||||
"""Enable or disable auto-mocking."""
|
"""Enable or disable auto-mocking."""
|
||||||
self._config.enable_auto_mock = enabled
|
self._config.enable_auto_mock = enabled
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
|
def with_delays(self, enabled: bool = True) -> MockConfigBuilder:
|
||||||
"""Enable or disable simulated execution delays."""
|
"""Enable or disable simulated execution delays."""
|
||||||
self._config.simulate_delays = enabled
|
self._config.simulate_delays = enabled
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_llm_response(self, response: str) -> "MockConfigBuilder":
|
def with_llm_response(self, response: str) -> MockConfigBuilder:
|
||||||
"""Set default LLM response."""
|
"""Set default LLM response."""
|
||||||
self._config.default_llm_response = response
|
self._config.default_llm_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_agent_response(self, response: str) -> "MockConfigBuilder":
|
def with_agent_response(self, response: str) -> MockConfigBuilder:
|
||||||
"""Set default agent response."""
|
"""Set default agent response."""
|
||||||
self._config.default_agent_response = response
|
self._config.default_agent_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||||
"""Set default tool response."""
|
"""Set default tool response."""
|
||||||
self._config.default_tool_response = response
|
self._config.default_tool_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
|
def with_retrieval_response(self, response: str) -> MockConfigBuilder:
|
||||||
"""Set default retrieval response."""
|
"""Set default retrieval response."""
|
||||||
self._config.default_retrieval_response = response
|
self._config.default_retrieval_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||||
"""Set default HTTP response."""
|
"""Set default HTTP response."""
|
||||||
self._config.default_http_response = response
|
self._config.default_http_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
|
def with_template_transform_response(self, response: str) -> MockConfigBuilder:
|
||||||
"""Set default template transform response."""
|
"""Set default template transform response."""
|
||||||
self._config.default_template_transform_response = response
|
self._config.default_template_transform_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||||
"""Set default code execution response."""
|
"""Set default code execution response."""
|
||||||
self._config.default_code_response = response
|
self._config.default_code_response = response
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
|
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder:
|
||||||
"""Set outputs for a specific node."""
|
"""Set outputs for a specific node."""
|
||||||
self._config.set_node_outputs(node_id, outputs)
|
self._config.set_node_outputs(node_id, outputs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
|
def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder:
|
||||||
"""Set error for a specific node."""
|
"""Set error for a specific node."""
|
||||||
self._config.set_node_error(node_id, error)
|
self._config.set_node_error(node_id, error)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
|
def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder:
|
||||||
"""Add a node-specific configuration."""
|
"""Add a node-specific configuration."""
|
||||||
self._config.set_node_config(config.node_id, config)
|
self._config.set_node_config(config.node_id, config)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
|
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder:
|
||||||
"""Set default configuration for a node type."""
|
"""Set default configuration for a node type."""
|
||||||
self._config.set_default_config(node_type, config)
|
self._config.set_default_config(node_type, config)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tool_node(monkeypatch) -> "ToolNode":
|
def tool_node(monkeypatch) -> ToolNode:
|
||||||
module_name = "core.ops.ops_trace_manager"
|
module_name = "core.ops.ops_trace_manager"
|
||||||
if module_name not in sys.modules:
|
if module_name not in sys.modules:
|
||||||
ops_stub = types.ModuleType(module_name)
|
ops_stub = types.ModuleType(module_name)
|
||||||
@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
|
|||||||
return events, stop.value
|
return events, stop.value
|
||||||
|
|
||||||
|
|
||||||
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
|
def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
|
||||||
def _identity_transform(messages, *_args, **_kwargs):
|
def _identity_transform(messages, *_args, **_kwargs):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
|
|||||||
return _collect_events(generator)
|
return _collect_events(generator)
|
||||||
|
|
||||||
|
|
||||||
def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
|
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
|
||||||
file_obj = File(
|
file_obj = File(
|
||||||
tenant_id="tenant-id",
|
tenant_id="tenant-id",
|
||||||
type=FileType.DOCUMENT,
|
type=FileType.DOCUMENT,
|
||||||
@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
|
|||||||
assert files_segment.value == [file_obj]
|
assert files_segment.value == [file_obj]
|
||||||
|
|
||||||
|
|
||||||
def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
|
def test_plain_link_messages_remain_links(tool_node: ToolNode):
|
||||||
message = ToolInvokeMessage(
|
message = ToolInvokeMessage(
|
||||||
type=ToolInvokeMessage.MessageType.LINK,
|
type=ToolInvokeMessage.MessageType.LINK,
|
||||||
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
|
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user