Merge branch 'main' into feat/grouping-branching

# Conflicts:
#	web/package.json
zhsama 2026-01-06 22:00:01 +08:00
commit 760a739e91
156 changed files with 5890 additions and 1553 deletions

View File

@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@ -60,9 +60,10 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, and import linter..."
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
@ -122,7 +123,7 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""

View File

@ -575,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Controls whether the `graph` field is written to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", the full `graph` field is written;
# otherwise an empty {} is written instead. Defaults to writing the `graph` field.
LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
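
As a rough sketch of how a boolean flag like this is typically consumed (the actual LogStore write path is not part of this diff, so the helper below and its name are hypothetical):

import os

def build_logstore_record(execution: dict) -> dict:
    # Hypothetical helper: include the full `graph` field only when the flag is
    # enabled; otherwise store an empty object in its place.
    put_graph = os.getenv("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
    record = dict(execution)
    record["graph"] = record.get("graph", {}) if put_graph else {}
    return record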

View File

@ -3,9 +3,11 @@ root_packages =
core
configs
controllers
extensions
models
tasks
services
include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@ -33,6 +35,28 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
core.workflow
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:rsc]
name = RSC
type = layers
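
The new workflow-infrastructure contract above forbids core.workflow modules from importing extensions.ext_database or extensions.ext_redis directly, outside the listed exceptions. In practice the workflow layer receives its infrastructure as injected dependencies, as the runner changes later in this commit do with ConversationVariableUpdater(session_factory.get_session_maker()). A minimal illustration of that boundary, with hypothetical names:

from sqlalchemy.orm import Session, sessionmaker

class WorkflowRepository:
    # Hypothetical workflow-layer class: the sessionmaker is passed in by the
    # caller instead of being imported from extensions.ext_database, which the
    # contract above forbids.
    def __init__(self, session_maker: sessionmaker[Session]) -> None:
        self._session_maker = session_maker

    def save(self, record_id: str, payload: dict) -> None:
        with self._session_maker() as session:
            ...  # persist the record using the injected session
            session.commit()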

View File

@ -50,16 +50,33 @@ WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
ARG NODE_MAJOR=22
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
&& apt-get install -y --no-install-recommends \
ca-certificates \
curl \
gnupg \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
&& gpg --show-keys --with-colons /tmp/nodesource.gpg \
| awk -F: '/^fpr:/ {print $10}' \
| grep -Fx "${NODESOURCE_KEY_FPR}" \
&& gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
&& rm -f /tmp/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
> /etc/apt/sources.list.d/nodesource.list \
&& apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs \
nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache

View File

@ -235,7 +235,7 @@ def migrate_annotation_vector_database():
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("file-usage", help="Query file usages and show where files are referenced.")
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
@click.option("--key", type=str, default=None, help="Filter by storage key.")
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
def file_usage(
file_id: str | None,
key: str | None,
src: str | None,
limit: int,
offset: int,
output_json: bool,
):
"""
Query file usages and show where files are referenced in the database.
This command reuses the same reference checking logic as clear-orphaned-file-records
and displays detailed information about where each file is referenced.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "id_column": "id", "key_column": "key"},
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
]
ids_tables = [
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
]
# Stream file usages with pagination to avoid holding all results in memory
paginated_usages = []
total_count = 0
# First, build a mapping of file_id -> storage_key from the base tables
file_key_map = {}
for files_table in files_tables:
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
# If filtering by key or file_id, verify it exists
if file_id and file_id not in file_key_map:
if output_json:
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
else:
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
return
if key:
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
if not matching_file_ids:
if output_json:
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
else:
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
return
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# For each reference table/column, find matching file IDs and record the references
for ids_table in ids_tables:
src_filter = f"{ids_table['table']}.{ids_table['column']}"
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
if src:
if "%" in src or "_" in src:
import fnmatch
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
pattern = src.replace("%", "*").replace("_", "?")
if not fnmatch.fnmatch(src_filter, pattern):
continue
else:
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
# Output results
if output_json:
result = {
"total": total_count,
"offset": offset,
"limit": limit,
"usages": paginated_usages,
}
click.echo(json.dumps(result, indent=2))
else:
click.echo(
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
)
click.echo("")
if not paginated_usages:
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
return
# Print table header
click.echo(
click.style(
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
fg="cyan",
)
)
click.echo(click.style("-" * 190, fg="white"))
# Print each usage
for usage in paginated_usages:
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
# Show pagination info
if offset + limit < total_count:
click.echo("")
click.echo(
click.style(
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
)
)
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")

View File

@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,

View File

@ -1,14 +1,16 @@
import re
import uuid
from typing import Literal
from datetime import datetime
from typing import Any, Literal, TypeAlias
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -19,27 +21,19 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import (
deleted_tool_fields,
model_config_fields,
model_config_partial_fields,
site_fields,
tag_fields,
)
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
@ -192,124 +186,292 @@ class AppTracePayload(BaseModel):
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
JSONValue: TypeAlias = Any
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
tag_model = console_ns.model("Tag", tag_fields)
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
model_config_model = console_ns.model("ModelConfig", model_config_fields)
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE.value:
return None
return file_helpers.get_signed_file_url(icon)
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
site_model = console_ns.model("Site", site_fields)
class Tag(ResponseModel):
id: str
name: str
type: str
app_partial_model = console_ns.model(
"AppPartial",
{
"id": fields.String,
"name": fields.String,
"max_active_requests": fields.Raw(),
"description": fields.String(attribute="desc_or_prompt"),
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_model)),
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
},
)
app_detail_model = console_ns.model(
"AppDetail",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"tracing": fields.Raw,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
},
)
class WorkflowPartial(ResponseModel):
id: str
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
app_detail_with_site_model = console_ns.model(
"AppDetailWithSite",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
"site": fields.Nested(site_model),
},
)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
app_pagination_model = console_ns.model(
"AppPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
},
class ModelConfigPartial(ResponseModel):
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
pre_prompt: str | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
)
suggested_questions_after_answer: JSONValue | None = Field(
default=None,
validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
)
speech_to_text: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
)
text_to_speech: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
)
retriever_resource: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
)
annotation_reply: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
)
more_like_this: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
)
sensitive_word_avoidance: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
)
external_data_tools: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
)
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
user_input_form: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
)
dataset_query_variable: str | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
prompt_type: str | None = None
chat_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
)
completion_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
)
dataset_configs: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
)
file_upload: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
)
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class Site(ResponseModel):
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str | None = None
icon_type: str | IconType | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str | None = None
chat_color_theme: str | None = None
chat_color_theme_inverted: bool | None = None
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str | None = None
prompt_public: bool | None = None
app_base_url: str | None = None
show_workflow_steps: bool | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
if isinstance(value, IconType):
return value.value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DeletedTool(ResponseModel):
type: str
tool_name: str
provider_id: str
class AppPartial(ResponseModel):
id: str
name: str
max_active_requests: int | None = None
description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
model_config_: ModelConfigPartial | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
tags: list[Tag] = Field(default_factory=list)
access_mode: str | None = None
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AppDetail(ResponseModel):
id: str
name: str
description: str | None = None
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon: str | None = None
icon_background: str | None = None
enable_site: bool
enable_api: bool
model_config_: ModelConfig | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
tracing: JSONValue | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
tags: list[Tag] = Field(default_factory=list)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AppDetailWithSite(AppDetail):
icon_type: str | None = None
api_base_url: str | None = None
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
page: int
limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
total: int
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
class AppExportResponse(ResponseModel):
data: str
register_schema_models(
console_ns,
AppListQuery,
CreateAppPayload,
UpdateAppPayload,
CopyAppPayload,
AppExportQuery,
AppNamePayload,
AppIconPayload,
AppSiteStatusPayload,
AppApiStatusPayload,
AppTracePayload,
Tag,
WorkflowPartial,
ModelConfigPartial,
ModelConfig,
Site,
DeletedTool,
AppPartial,
AppDetail,
AppDetailWithSite,
AppPagination,
AppExportResponse,
)
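
These Pydantic response models replace the flask_restx field dictionaries and @marshal_with decorators that the rest of this file's diff removes. The handlers below follow one pattern: validate the ORM object with model_validate(..., from_attributes=True), letting AliasChoices map attribute names (e.g. mode_compatible_with_agent to mode) and the before-validators normalize datetimes to epoch seconds, then serialize with model_dump(mode="json"). A self-contained sketch mirroring the models above (illustrative names only):

from datetime import datetime
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator

class ExampleDetail(BaseModel):
    model_config = ConfigDict(from_attributes=True, extra="ignore", populate_by_name=True)

    name: str
    mode: str = Field(validation_alias=AliasChoices("mode_compatible_with_agent", "mode"))
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        return int(value.timestamp()) if isinstance(value, datetime) else value

class FakeApp:  # stand-in for an ORM object
    name = "demo"
    mode_compatible_with_agent = "chat"
    created_at = datetime(2026, 1, 6)

print(ExampleDetail.model_validate(FakeApp(), from_attributes=True).model_dump(mode="json"))
# -> {"name": "demo", "mode": "chat", "created_at": <epoch seconds>}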
@ -318,7 +480,7 @@ class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.response(200, "Success", app_pagination_model)
@console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
@setup_required
@login_required
@account_initialization_required
@ -334,7 +496,8 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
@ -378,18 +541,18 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_model), 200
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@ -399,8 +562,8 @@ class AppListApi(Resource):
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
return app, 201
app_detail = AppDetail.model_validate(app, from_attributes=True)
return app_detail.model_dump(mode="json"), 201
@console_ns.route("/apps/<uuid:app_id>")
@ -408,13 +571,12 @@ class AppApi(Resource):
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Success", app_detail_with_site_model)
@console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
@marshal_with(app_detail_with_site_model)
@get_app_model(mode=None)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@ -425,21 +587,21 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
return app_model
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@get_app_model(mode=None)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@ -456,8 +618,8 @@ class AppApi(Resource):
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
return app_model
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@ -483,14 +645,13 @@ class AppCopyApi(Resource):
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@get_app_model(mode=None)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
@ -516,7 +677,8 @@ class AppCopyApi(Resource):
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
return app, 201
response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
return response_model.model_dump(mode="json"), 201
@console_ns.route("/apps/<uuid:app_id>/export")
@ -525,11 +687,7 @@ class AppExportApi(Resource):
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.response(
200,
"App exported successfully",
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
@console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@ -540,13 +698,14 @@ class AppExportApi(Resource):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {
"data": AppDslService.export_dsl(
payload = AppExportResponse(
data=AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
}
)
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/name")
@ -555,20 +714,19 @@ class AppNameApi(Resource):
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked")
@console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/icon")
@ -582,16 +740,15 @@ class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/site-enable")
@ -600,21 +757,20 @@ class AppSiteStatus(Resource):
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/api-enable")
@ -623,21 +779,20 @@ class AppApiStatus(Resource):
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/trace")

View File

@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
)
if args.keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
keyword_filter = f"%{args.keyword}%"
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(args.keyword)
keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
@ -469,11 +475,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
Message.query.ilike(keyword_filter, escape="\\"),
Message.answer.ilike(keyword_filter, escape="\\"),
Conversation.name.ilike(keyword_filter, escape="\\"),
Conversation.introduction.ilike(keyword_filter, escape="\\"),
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)
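
escape_like_pattern is imported from libs.helper; its implementation is not part of this diff, but its pairing with ilike(..., escape="\\") implies something along these lines (a sketch, not the actual helper):

def escape_like_pattern(value: str) -> str:
    # Plausible implementation: escape the escape character first, then the
    # LIKE wildcards, so user-supplied keywords match literally.
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

# Usage mirrors the queries above:
#   Message.query.ilike(f"%{escape_like_pattern(keyword)}%", escape="\\")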

View File

@ -1,3 +1,5 @@
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
@ -96,14 +98,13 @@ class LoginApi(Resource):
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
invitation_data: dict[str, Any] | None = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
try:
if invitation:
data = invitation.get("data", {}) # type: ignore
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
raise InvalidEmailError()

View File

@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,

View File

@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
# Escape LIKE wildcards in the keyword so user input matches literally rather than as a pattern
escaped_keyword = escape_like_pattern(keyword)
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
).ilike(f"%{escaped_keyword}%", escape="\\")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
keywords_condition,
)
)

View File

@ -1,7 +1,7 @@
import logging
from typing import Any
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
"""Validate and return hit-testing arguments from an incoming payload."""
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset, args):

View File

@ -4,12 +4,11 @@ from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
# For rename only, just update the name
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
# When the credential type is UNAUTHORIZED, the subscription was created manually.
# Manually created subscriptions have no credentials or parameters;
# they only have a name and properties (provided by the user).
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
if rename or manually_created:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
name=request.name,
properties=request.properties,
)
return 200
# rebuild subscriptions that were created automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
# For the remaining cases (API_KEY, OAUTH2),
# we need to call the third-party provider (e.g. GitHub) to rebuild the subscription
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=request.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=request.credentials or subscription.credentials,
parameters=request.parameters or subscription.parameters,
)
return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:

View File

@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
args = self.parse_args(service_api_ns.payload)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -1,9 +1,11 @@
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
class ConversationListQuery(BaseModel):
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
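
Because ConversationListQuery is validated directly against request.args (see the GET handler below), Pydantic's lax coercion replaces both the reqparse parser and the manual pinned string-to-bool conversion that this hunk removes. For example, assuming default lax validation:

query = ConversationListQuery.model_validate({"limit": "50", "pinned": "true"})
assert query.limit == 50 and query.pinned is True and query.sort_by == "-updated_at"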
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args)
try:
with Session(db.engine) as session:
@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource):
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=query.last_id,
limit=query.limit,
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
pinned=query.pinned,
sort_by=query.sort_by,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try:
conversation = ConversationService.rename(
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
app_model, conversation_id, end_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)

View File

@ -1,18 +1,30 @@
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
@web_ns.doc("Get Saved Messages")
@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
try:
SavedMessageService.save(app_model, end_user, args["message_id"])
SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -40,6 +42,7 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
conversation_variable_layer = ConversationVariablePersistenceLayer(
ConversationVariableUpdater(session_factory.get_session_maker())
)
workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)

View File

@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.question_text,
annotation.content,
query,
user_id,

View File

@ -0,0 +1,60 @@
import logging
from core.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
logger = logging.getLogger(__name__)
class ConversationVariablePersistenceLayer(GraphEngineLayer):
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
if not updated_variables:
return
conversation_id = self.graph_runtime_state.system_variable.conversation_id
if conversation_id is None:
return
updated_any = False
for item in updated_variables:
selector = item.selector
if len(selector) < 2:
logger.warning("Conversation variable selector invalid. selector=%s", selector)
continue
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
)
continue
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
updated_any = True
if updated_any:
self._conversation_variable_updater.flush()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)

View File

@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICALLY, workflow_execution_id is the same as workflow_run_id

View File

@ -1,7 +1,7 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
_session_maker: sessionmaker | None = None
_session_maker: sessionmaker[Session] | None = None
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
if _session_maker is None:
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
return _session_maker
@ -27,7 +27,7 @@ class SessionFactory:
configure_session_factory(engine, expire_on_commit)
@staticmethod
def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
return get_session_maker()
@staticmethod

View File

@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()

View File

@ -112,17 +112,17 @@ class File(BaseModel):
return text
def generate_url(self) -> str | None:
def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id)
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@ -133,7 +133,7 @@ class File(BaseModel):
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(),
"url": self.generate_url(for_external=False),
}
@model_validator(mode="after")

View File

@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
return convert_scientific_notation(result)
@classmethod
@abstractmethod

View File

@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query).replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""

View File

@ -287,11 +287,15 @@ class IrisVector(BaseVector):
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
WHERE text LIKE ? ESCAPE '\\'
"""
cursor.execute(sql, (query_pattern,))

View File

@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
_DOCUMENT_ID_PROPERTY = "document_id"
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))

View File

@ -7,10 +7,11 @@ import re
import tempfile
import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree
import httpx
from docx import Document as DocxDocument
from docx.oxml.ns import qn
from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc)
hyperlinks_url = None
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
for para in doc.paragraphs:
for run in para.runs:
if run.text and hyperlinks_url:
result = f" [{run.text}]({hyperlinks_url}) "
run.text = result
hyperlinks_url = None
if "HYPERLINK" in run.element.xml:
try:
xml = ElementTree.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None]
for x in x_child:
if x is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception:
logger.exception("Failed to parse HYPERLINK xml")
def parse_paragraph(paragraph):
paragraph_content = []
def append_image_link(image_id, has_drawing):
def append_image_link(image_id, has_drawing, target_buffer):
"""Helper to append image link from image_map based on relationship type."""
rel = doc.part.rels[image_id]
if rel.is_external:
if image_id in image_map and not has_drawing:
paragraph_content.append(image_map[image_id])
target_buffer.append(image_map[image_id])
else:
image_part = rel.target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
for run in paragraph.runs:
def process_run(run, target_buffer):
# Helper to extract text and embedded images from a run element and append them to target_buffer
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
# External image: use embed_id as key
if embed_id in image_map:
has_drawing = True
paragraph_content.append(image_map[embed_id])
target_buffer.append(image_map[embed_id])
else:
# Internal image: use target_part as key
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
target_buffer.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
append_image_link(image_id, has_drawing)
append_image_link(image_id, has_drawing, target_buffer)
if run.text.strip():
paragraph_content.append(run.text.strip())
target_buffer.append(run.text.strip())
def process_hyperlink(hyperlink_elem, target_buffer):
# Helper to extract text from a hyperlink element and append it to target_buffer
r_id = hyperlink_elem.get(qn("r:id"))
# Extract text from runs inside the hyperlink
link_text_parts = []
for run_elem in hyperlink_elem.findall(qn("w:r")):
run = Run(run_elem, paragraph)
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
# so collect all run texts first
if run.text:
link_text_parts.append(run.text)
link_text = "".join(link_text_parts).strip()
# Resolve URL
if r_id:
try:
rel = doc.part.rels.get(r_id)
if rel and rel.is_external:
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
except Exception:
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
if link_text:
target_buffer.append(link_text)
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:
tag = child.tag
if tag == qn("w:r"):
# Regular run
run = Run(child, paragraph)
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
fld_chars = child.findall(qn("w:fldChar"))
instr_texts = child.findall(qn("w:instrText"))
# Handle Fields
if fld_chars or instr_texts:
# Process instrText to find HYPERLINK "url"
for instr in instr_texts:
if instr.text and "HYPERLINK" in instr.text:
# Quick regex to extract URL
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
if match:
hyperlink_field_url = match.group(1)
# Process fldChar
for fld_char in fld_chars:
fld_char_type = fld_char.get(qn("w:fldCharType"))
if fld_char_type == "begin":
# Start of a field: reset legacy link state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
elif fld_char_type == "separate":
# Separator: if we found a URL, start collecting visible text
if hyperlink_field_url:
is_collecting_field_text = True
elif fld_char_type == "end":
# End of field
if is_collecting_field_text and hyperlink_field_url:
# Create markdown link and append to main content
display_text = "".join(hyperlink_field_text_parts).strip()
if display_text:
link_md = f"[{display_text}]({hyperlink_field_url})"
paragraph_content.append(link_md)
# Reset state
hyperlink_field_url = None
hyperlink_field_text_parts = []
is_collecting_field_text = False
# Decide where to append content
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
process_run(run, target_buffer)
elif tag == qn("w:hyperlink"):
process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()

View File

@ -1198,18 +1198,24 @@ class DatasetRetrieval:
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
from libs.helper import escape_like_pattern
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
filters.append(json_field.like(f"{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
filters.append(json_field.like(f"%{value}"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):
@ -1474,38 +1480,38 @@ class DatasetRetrieval:
if cancel_event and cancel_event.is_set():
break
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()

View File

@ -7,12 +7,12 @@ import time
from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
"""
# Use internal URL for plugin/tool file access in Docker environments
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))

View File

@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
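For example, a minimal custom layer (a sketch; `TokenBudgetLayer` is a hypothetical name and `engine` is assumed to be an existing `GraphEngine` instance) can read the bound state directly in its hooks:
```python
class TokenBudgetLayer(GraphEngineLayer):
    """Hypothetical layer: logs total token usage once the graph finishes."""

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event) -> None:
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        # graph_runtime_state is bound by engine.layer(), so it is safe to read here
        print(f"total tokens: {self.graph_runtime_state.total_tokens}")


engine.layer(TokenBudgetLayer())
```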
### Event-Driven Architecture
All node executions emit events for monitoring and integration:

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
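Because `CommandType` is a `StrEnum`, `auto()` resolves to the lower-cased member name, so the serialized values for existing commands stay `"abort"` and `"pause"`. A quick illustrative sketch:
```python
from enum import StrEnum, auto


class CommandType(StrEnum):
    ABORT = auto()
    PAUSE = auto()
    UPDATE_VARIABLES = auto()


# auto() on a StrEnum uses the lower-cased member name as the value
assert CommandType.ABORT == "abort"
assert CommandType.UPDATE_VARIABLES == "update_variables"
```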

View File

@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability.
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -30,8 +31,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -70,10 +76,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@ -140,6 +149,9 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
@ -169,6 +181,7 @@ class GraphEngine:
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -199,6 +212,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -212,9 +226,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@ -301,14 +322,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
try:
layer.on_graph_start()
except Exception as e:
@ -316,6 +330,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
@ -343,13 +358,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
logger = logging.getLogger(__name__)
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)

View File

@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod

View File

@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""

View File

@ -42,6 +42,7 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -65,13 +66,16 @@ class Worker(threading.Thread):
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:

View File

@ -41,6 +41,7 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -81,6 +82,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +137,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
worker.join(timeout=2.0)
self._workers.clear()
@ -152,6 +154,7 @@ class WorkerPool:
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
stop_event=self._stop_event,
)
worker.start()

View File

@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -106,6 +113,14 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered
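As a sketch of the seam this protocol introduces (hypothetical class; assumes `jinja2` is importable in the calling environment), an in-process renderer only needs to provide `render_template`:
```python
from collections.abc import Mapping
from typing import Any

from jinja2 import Template  # assumption: jinja2 is available here


class LocalJinja2TemplateRenderer:
    """Hypothetical renderer: evaluates templates in-process instead of via CodeExecutor."""

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        return Template(template).render(**variables)


# e.g. DifyNodeFactory(..., template_renderer=LocalJinja2TemplateRenderer())
```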

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -1,28 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

View File

@ -1,9 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={

View File

@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@classmethod
def version(cls) -> str:
return "2"
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -168,6 +169,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:

View File

@ -3,8 +3,9 @@
set -e
# Set UTF-8 encoding to address potential encoding issues in containerized environments
export LANG=${LANG:-en_US.UTF-8}
export LC_ALL=${LC_ALL:-en_US.UTF-8}
# Use C.UTF-8, which is universally available in containers
export LANG=${LANG:-C.UTF-8}
export LC_ALL=${LC_ALL:-C.UTF-8}
export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8}
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then

View File

@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
@ -42,10 +43,28 @@ def init_app(app: DifyApp):
_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
resources={
# Embedded bot endpoints (unauthenticated, cross-origin safe)
r"^/chat-messages$": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
r"^/chat-messages/.*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
# Default web application endpoints (authenticated)
r"/*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": True,
"allow_headers": list(AUTHENTICATED_HEADERS),
"methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
},
},
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)

View File

@ -11,6 +11,7 @@ def init_app(app: DifyApp):
create_tenant,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@ -47,6 +48,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,

View File

@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()
# Eagerly build the engine so pool_size/max_overflow/etc. come from config
try:
with app.app_context():
_ = db.engine # triggers engine creation with the configured options
except Exception:
logger.exception("Failed to initialize SQLAlchemy engine during app startup")

View File

@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -69,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
@ -108,9 +125,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),

View File

@ -32,6 +32,38 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.
This function escapes the special characters used in SQL LIKE patterns:
- Backslash (\\) -> \\\\
- Percent (%) -> \\%
- Underscore (_) -> \\_
The escaped pattern can then be safely used in SQL LIKE queries with the
ESCAPE '\\' clause, so user-supplied wildcards cannot alter the match pattern.
Args:
pattern: The string pattern to escape
Returns:
Escaped string safe for use in SQL LIKE queries
Examples:
>>> escape_like_pattern("50% discount")
'50\\% discount'
>>> escape_like_pattern("test_data")
'test\\_data'
>>> escape_like_pattern("path\\to\\file")
'path\\\\to\\\\file'
"""
if not pattern:
return pattern
# Escape backslash first, then percent and underscore
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.

View File

@ -70,6 +70,7 @@ class AppMode(StrEnum):
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
LINK = auto()
class App(Base):
@ -81,7 +82,7 @@ class App(Base):
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@ -1419,15 +1420,20 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
question = mapped_column(LongText, nullable=True)
content = mapped_column(LongText, nullable=False)
question: Mapped[str | None] = mapped_column(LongText, nullable=True)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def question_text(self) -> str:
"""Return a non-null question string, falling back to the answer content."""
return self.question or self.content
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()

View File

@ -1506,6 +1506,7 @@ class WorkflowDraftVariable(Base):
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
variable.updated_at = naive_utc_now()
variable.description = description

View File

@ -77,7 +77,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
annotation.question,
question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@ -137,13 +137,16 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(keyword)
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike(f"%{keyword}%"),
MessageAnnotation.content.ilike(f"%{keyword}%"),
MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
@ -253,7 +256,7 @@ class AppAnnotationService:
if app_annotation_setting:
update_annotation_to_index_task.delay(
annotation.id,
annotation.question,
annotation.question_text,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,

View File

@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig
from models.model import AppModelConfig, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@ -428,10 +428,10 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
if icon_type_value in ["emoji", "link", "image"]:
if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon_type = IconType.EMOJI.value
icon = icon or str(app_data.get("icon", ""))
if app:

View File

@ -55,8 +55,11 @@ class AppService:
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
from libs.helper import escape_like_pattern
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])

View File

@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@ -218,7 +218,9 @@ class ConversationService:
# Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
from libs.helper import escape_like_pattern
escaped_variable_name = escape_like_pattern(variable_name)
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
@ -335,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()

View File

@ -0,0 +1,28 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from models import ConversationVariable
class ConversationVariableNotFoundError(Exception):
pass
class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with self._session_maker() as session:
row = session.scalar(stmt)
if not row:
raise ConversationVariableNotFoundError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self) -> None:
pass
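A brief usage sketch of the updater above, mirroring the ConversationService call site (the conversation id and variable are assumed to come from the surrounding service code):
from core.db.session_factory import session_factory

updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)  # raises ConversationVariableNotFoundError if no matching row exists
updater.flush()  # currently a no-op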

View File

@ -144,7 +144,8 @@ class DatasetService:
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
query = query.where(Dataset.name.ilike(f"%{search}%"))
escaped_search = helper.escape_like_pattern(search)
query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
@ -3423,7 +3424,8 @@ class SegmentService:
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
@ -3456,7 +3458,8 @@ class SegmentService:
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

View File

@ -35,7 +35,10 @@ class ExternalDatasetService:
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
from libs.helper import escape_like_pattern
escaped_search = escape_like_pattern(search)
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False

View File

@ -19,7 +19,10 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
from libs.helper import escape_like_pattern
escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

View File

@ -853,7 +853,7 @@ class TriggerProviderService:
"""
Create a subscription builder for rebuilding an existing subscription.
This method creates a builder pre-filled with data from the rebuild request,
This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub),
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
:param tenant_id: Tenant ID
@ -868,111 +868,50 @@ class TriggerProviderService:
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
raise ValueError(f"Credential type {credential_type} not supported for auto creation")
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# Delete the previous subscription
user_id = subscription.user_id
unsubscribe_result = TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
if not unsubscribe_result.success:
raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
user_id = subscription.user_id
# TODO: try invoking the update API of the plugin trigger provider.
# FALLBACK: if the update API is not implemented,
# delete the previous subscription and create a new one.
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise
# Create a new subscription with the same subscription_id and endpoint_id
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=new_credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=new_credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)
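To make the HIDDEN_VALUE handling above concrete, a small illustrative sketch (the credential values are invented):
stored = {"api_key": "original-key"}   # credentials already held for the subscription
incoming = {"api_key": HIDDEN_VALUE}   # the caller masked the secret instead of resending it
merged = {k: v if v != HIDDEN_VALUE else stored.get(k, UNKNOWN_VALUE) for k, v in incoming.items()}
# merged == {"api_key": "original-key"}, so the existing secret is reused when re-creating the subscription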

View File

@ -86,12 +86,19 @@ class WorkflowAppService:
# Join to workflow run for filtering when needed.
if keyword:
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
from libs.helper import escape_like_pattern
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword[:30])
keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
and_(
WorkflowRun.created_by_role == "end_user",
EndUser.session_id.ilike(keyword_like_val, escape="\\"),
),
]
# filter keyword by workflow run id

View File

@ -679,6 +679,7 @@ def _batch_upsert_draft_variable(
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"id": model.id,
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,

View File

@ -98,7 +98,7 @@ def enable_annotation_reply_task(
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)

View File

@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers:
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
# Don't initialize - graph_runtime_state should be uninitialized
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
# Act & Assert - Should raise GraphEngineLayerNotInitializedError
with pytest.raises(GraphEngineLayerNotInitializedError):
layer.on_event(event)

View File

@ -444,6 +444,78 @@ class TestAnnotationService:
assert total == 1
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotations with special characters in content
annotation_with_percent = {
"question": "Question with 50% discount",
"answer": "Answer about 50% discount offer",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id)
annotation_with_underscore = {
"question": "Question with test_data",
"answer": "Answer about test_data value",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id)
annotation_with_backslash = {
"question": "Question with path\\to\\file",
"answer": "Answer about path\\to\\file location",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id)
# Create annotation that should NOT match (contains % but as part of different text)
annotation_no_match = {
"question": "Question with 100% different",
"answer": "Answer about 100% different content",
}
AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id)
# Test 1: Search with % character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
assert total == 1
assert len(annotation_list) == 1
assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content
# Test 2: Search with _ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="test_data"
)
assert total == 1
assert len(annotation_list) == 1
assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content
# Test 3: Search with \ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="path\\to\\file"
)
assert total == 1
assert len(annotation_list) == 1
assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content
# Test 4: Search with % should NOT match 100% (verifies escaping works)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
# Should only find the 50% annotation, not the 100% one
assert total == 1
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
class TestAppService:
@ -71,6 +73,9 @@ class TestAppService:
}
# Create app
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -109,6 +114,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Test different app modes
@ -159,6 +167,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
@ -194,6 +205,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create multiple apps
@ -245,6 +259,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create apps with different modes
@ -315,6 +332,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create an app
@ -392,6 +412,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -458,6 +481,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -508,6 +534,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -562,6 +591,9 @@ class TestAppService:
"icon_background": "#74B9FF",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -617,6 +649,9 @@ class TestAppService:
"icon_background": "#A29BFE",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -672,6 +707,9 @@ class TestAppService:
"icon_background": "#FD79A8",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -720,6 +758,9 @@ class TestAppService:
"icon_background": "#E17055",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -768,6 +809,9 @@ class TestAppService:
"icon_background": "#00B894",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -826,6 +870,9 @@ class TestAppService:
"icon_background": "#6C5CE7",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -862,6 +909,9 @@ class TestAppService:
"icon_background": "#FDCB6E",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -899,6 +949,9 @@ class TestAppService:
"icon_background": "#E84393",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -947,8 +1000,132 @@ class TestAppService:
"icon_background": "#D63031",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Attempt to create app with invalid mode
with pytest.raises(ValueError, match="invalid mode value"):
app_service.create_app(tenant.id, app_args, account)
def test_get_apps_with_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test app retrieval with special characters in name search to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in name search are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
fake = Faker()
# Create account and tenant first
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create apps with special characters in names
app_with_percent = app_service.create_app(
tenant.id,
{
"name": "App with 50% discount",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
app_with_underscore = app_service.create_app(
tenant.id,
{
"name": "test_data_app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
app_with_backslash = app_service.create_app(
tenant.id,
{
"name": "path\\to\\app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
# Create app that should NOT match
app_no_match = app_service.create_app(
tenant.id,
{
"name": "100% different",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
account,
)
# Test 1: Search with % character
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "App with 50% discount"
# Test 2: Search with _ character
args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "test_data_app"
# Test 3: Search with \ character
args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "path\\to\\app"
# Test 4: Search with % should NOT match 100% (verifies escaping works)
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)

View File

@ -1,3 +1,4 @@
import uuid
from unittest.mock import create_autospec, patch
import pytest
@ -312,6 +313,85 @@ class TestTagService:
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _, \) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
from extensions.ext_database import db
# Create tags with special characters in names
tag_with_percent = Tag(
name="50% discount",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_percent.id = str(uuid.uuid4())
db.session.add(tag_with_percent)
tag_with_underscore = Tag(
name="test_data_tag",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_underscore.id = str(uuid.uuid4())
db.session.add(tag_with_underscore)
tag_with_backslash = Tag(
name="path\\to\\tag",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_with_backslash.id = str(uuid.uuid4())
db.session.add(tag_with_backslash)
# Create tag that should NOT match
tag_no_match = Tag(
name="100% different",
type="app",
tenant_id=tenant.id,
created_by=account.id,
)
tag_no_match.id = str(uuid.uuid4())
db.session.add(tag_no_match)
db.session.commit()
# Act & Assert: Test 1 - Search with % character
result = TagService.get_tags("app", tenant.id, keyword="50%")
assert len(result) == 1
assert result[0].name == "50% discount"
# Test 2 - Search with _ character
result = TagService.get_tags("app", tenant.id, keyword="test_data")
assert len(result) == 1
assert result[0].name == "test_data_tag"
# Test 3 - Search with \ character
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
assert len(result) == 1
assert result[0].name == "path\\to\\tag"
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
result = TagService.get_tags("app", tenant.id, keyword="50%")
assert len(result) == 1
assert all("50%" in item.name for item in result)
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test tag retrieval when no tags exist.

View File

@ -474,64 +474,6 @@ class TestTriggerProviderService:
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
@ -558,70 +500,6 @@ class TestTriggerProviderService:
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
@ -86,6 +88,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -147,6 +152,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -308,6 +316,156 @@ class TestWorkflowAppService:
assert result_no_match["total"] == 0
assert len(result_no_match["data"]) == 0
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
This test verifies:
- Special characters (%, _) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
from extensions.ext_database import db
service = WorkflowAppService()
# Test 1: Search with % character
workflow_run_1 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}),
outputs=json.dumps({"result": "50% discount applied", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_1)
db.session.flush()
workflow_app_log_1 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_1.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_1.id = str(uuid.uuid4())
workflow_app_log_1.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_1)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
)
# Should find the workflow_run_1 entry
assert result["total"] >= 1
assert len(result["data"]) >= 1
assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"])
# Test 2: Search with _ character
workflow_run_2 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}),
outputs=json.dumps({"result": "test_data_value found", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_2)
db.session.flush()
workflow_app_log_2 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_2.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_2.id = str(uuid.uuid4())
workflow_app_log_2.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_2)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
)
# Should find the workflow_run_2 entry
assert result["total"] >= 1
assert len(result["data"]) >= 1
assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"])
# Test 3: Search with % should NOT match 100% (verifies escaping works correctly)
workflow_run_4 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
type="workflow",
triggered_from="app-run",
version="1.0.0",
graph=json.dumps({"nodes": [], "edges": []}),
status="succeeded",
inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}),
outputs=json.dumps({"result": "100% different result", "status": "success"}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_4)
db.session.flush()
workflow_app_log_4 = WorkflowAppLog(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_4.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
workflow_app_log_4.id = str(uuid.uuid4())
workflow_app_log_4.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_4)
db.session.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
)
# Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4)
# This verifies that escaping works correctly - 50% should not match 100%
assert result["total"] >= 1
assert len(result["data"]) >= 1
# Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different)
found_run_ids = [log.workflow_run_id for log in result["data"]]
assert workflow_run_1.id in found_run_ids
assert workflow_run_4.id not in found_run_ids
def test_get_paginate_workflow_app_logs_with_status_filter(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -0,0 +1,285 @@
from __future__ import annotations
import builtins
import sys
from datetime import datetime
from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import pytest
from flask.views import MethodView
# kombu references MethodView as a global when importing celery/kombu pools.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_app_module():
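"""Load controllers.console.app.app in isolation by stubbing its parent packages and heavy dependencies (e.g. core.ops.ops_trace_manager), then restoring sys.modules afterwards."""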
module_name = "controllers.console.app.app"
if module_name in sys.modules:
return sys.modules[module_name]
root = Path(__file__).resolve().parents[5]
module_path = root / "controllers" / "console" / "app" / "app.py"
class _StubNamespace:
def __init__(self):
self.models: dict[str, Any] = {}
self.payload = None
def schema_model(self, name, schema):
self.models[name] = schema
def _decorator(self, obj):
return obj
def doc(self, *args, **kwargs):
return self._decorator
def expect(self, *args, **kwargs):
return self._decorator
def response(self, *args, **kwargs):
return self._decorator
def route(self, *args, **kwargs):
def decorator(obj):
return obj
return decorator
stub_namespace = _StubNamespace()
original_console = sys.modules.get("controllers.console")
original_app_pkg = sys.modules.get("controllers.console.app")
stubbed_modules: list[tuple[str, ModuleType | None]] = []
console_module = ModuleType("controllers.console")
console_module.__path__ = [str(root / "controllers" / "console")]
console_module.console_ns = stub_namespace
console_module.api = None
console_module.bp = None
sys.modules["controllers.console"] = console_module
app_package = ModuleType("controllers.console.app")
app_package.__path__ = [str(root / "controllers" / "console" / "app")]
sys.modules["controllers.console.app"] = app_package
console_module.app = app_package
def _stub_module(name: str, attrs: dict[str, Any]):
original = sys.modules.get(name)
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
sys.modules[name] = module
stubbed_modules.append((name, original))
class _OpsTraceManager:
@staticmethod
def get_app_tracing_config(app_id: str) -> dict[str, Any]:
return {}
@staticmethod
def update_app_tracing_config(app_id: str, **kwargs) -> None:
return None
_stub_module(
"core.ops.ops_trace_manager",
{
"OpsTraceManager": _OpsTraceManager,
"TraceQueueManager": object,
"TraceTask": object,
},
)
spec = util.spec_from_file_location(module_name, module_path)
module = util.module_from_spec(spec)
sys.modules[module_name] = module
try:
assert spec.loader is not None
spec.loader.exec_module(module)
finally:
for name, original in reversed(stubbed_modules):
if original is not None:
sys.modules[name] = original
else:
sys.modules.pop(name, None)
if original_console is not None:
sys.modules["controllers.console"] = original_console
else:
sys.modules.pop("controllers.console", None)
if original_app_pkg is not None:
sys.modules["controllers.console.app"] = original_app_pkg
else:
sys.modules.pop("controllers.console.app", None)
return module
_app_module = _load_app_module()
AppDetailWithSite = _app_module.AppDetailWithSite
AppPagination = _app_module.AppPagination
AppPartial = _app_module.AppPartial
@pytest.fixture(autouse=True)
def patch_signed_url(monkeypatch):
"""Ensure icon URL generation uses a deterministic helper for tests."""
def _fake_signed_url(key: str | None) -> str | None:
if not key:
return None
return f"signed:{key}"
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
def _ts(hour: int = 12) -> datetime:
return datetime(2024, 1, 1, hour, 0, 0)
def _dummy_model_config():
return SimpleNamespace(
model_dict={"provider": "openai", "name": "gpt-4o"},
pre_prompt="hello",
created_by="config-author",
created_at=_ts(9),
updated_by="config-editor",
updated_at=_ts(10),
)
def _dummy_workflow():
return SimpleNamespace(
id="wf-1",
created_by="workflow-author",
created_at=_ts(8),
updated_by="workflow-editor",
updated_at=_ts(9),
)
def test_app_partial_serialization_uses_aliases():
created_at = _ts()
app_obj = SimpleNamespace(
id="app-1",
name="My App",
desc_or_prompt="Prompt snippet",
mode_compatible_with_agent="chat",
icon_type="image",
icon="icon-key",
icon_background="#fff",
app_model_config=_dummy_model_config(),
workflow=_dummy_workflow(),
created_by="creator",
created_at=created_at,
updated_by="editor",
updated_at=created_at,
tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")],
access_mode="private",
create_user_name="Creator",
author_name="Author",
has_draft_trigger=True,
)
serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
assert serialized["description"] == "Prompt snippet"
assert serialized["mode"] == "chat"
assert serialized["icon_url"] == "signed:icon-key"
assert serialized["created_at"] == int(created_at.timestamp())
assert serialized["updated_at"] == int(created_at.timestamp())
assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
assert serialized["workflow"]["id"] == "wf-1"
assert serialized["tags"][0]["name"] == "Utilities"
def test_app_detail_with_site_includes_nested_serialization():
timestamp = _ts(14)
site = SimpleNamespace(
code="site-code",
title="Public Site",
icon_type="image",
icon="site-icon",
created_at=timestamp,
updated_at=timestamp,
)
app_obj = SimpleNamespace(
id="app-2",
name="Detailed App",
description="Desc",
mode_compatible_with_agent="advanced-chat",
icon_type="image",
icon="detail-icon",
icon_background="#123456",
enable_site=True,
enable_api=True,
app_model_config={
"opening_statement": "hi",
"model": {"provider": "openai", "name": "gpt-4o"},
"retriever_resource": {"enabled": True},
},
workflow=_dummy_workflow(),
tracing={"enabled": True},
use_icon_as_answer_icon=True,
created_by="creator",
created_at=timestamp,
updated_by="editor",
updated_at=timestamp,
access_mode="public",
tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")],
api_base_url="https://api.example.com/v1",
max_active_requests=5,
deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}],
site=site,
)
serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
assert serialized["icon_url"] == "signed:detail-icon"
assert serialized["model_config"]["retriever_resource"] == {"enabled": True}
assert serialized["deleted_tools"][0]["tool_name"] == "search"
assert serialized["site"]["icon_url"] == "signed:site-icon"
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
def test_app_pagination_aliases_per_page_and_has_next():
item_one = SimpleNamespace(
id="app-10",
name="Paginated One",
desc_or_prompt="Summary",
mode_compatible_with_agent="chat",
icon_type="image",
icon="first-icon",
created_at=_ts(15),
updated_at=_ts(15),
)
item_two = SimpleNamespace(
id="app-11",
name="Paginated Two",
desc_or_prompt="Summary",
mode_compatible_with_agent="agent-chat",
icon_type="emoji",
icon="🙂",
created_at=_ts(16),
updated_at=_ts(16),
)
pagination = SimpleNamespace(
page=2,
per_page=10,
total=50,
has_next=True,
items=[item_one, item_two],
)
serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json")
assert serialized["page"] == 2
assert serialized["limit"] == 10
assert serialized["has_more"] is True
assert len(serialized["data"]) == 2
assert serialized["data"][0]["icon_url"] == "signed:first-icon"
assert serialized["data"][1]["icon_url"] is None

View File

@ -0,0 +1,145 @@
"""
Test for document detail API data_source_info serialization fix.
This test verifies that the document detail API returns both data_source_info
and data_source_detail_dict for all data_source_type values, including "local_file".
"""
import json
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union
from models.dataset import Document
class LocalFileInfo(TypedDict):
file_path: str
size: int
created_at: NotRequired[str]
class UploadFileInfo(TypedDict):
upload_file_id: str
class NotionImportInfo(TypedDict):
notion_page_id: str
workspace_id: str
class WebsiteCrawlInfo(TypedDict):
url: str
job_id: str
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
T_type = TypeVar("T_type", bound=str)
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
class Case(TypedDict, Generic[T_type, T_info]):
data_source_type: T_type
data_source_info: str
expected_raw: T_info
LocalFileCase = Case[Literal["local_file"], LocalFileInfo]
UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
case_1: LocalFileCase = {
"data_source_type": "local_file",
"data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}),
"expected_raw": {"file_path": "/tmp/test.txt", "size": 1024},
}
# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo
case_2: LocalFileCase = {
"data_source_type": "local_file",
"data_source_info": "...",
"expected_raw": {"file_path": "https://google.com", "size": 123},
}
cases: list[AnyCase] = [case_1]
class TestDocumentDetailDataSourceInfo:
"""Test cases for document detail API data_source_info serialization."""
def test_data_source_info_dict_returns_raw_data(self):
"""Test that data_source_info_dict returns raw JSON data for all data_source_type values."""
# Test data for different data_source_type values
for case in cases:
document = Document(
data_source_type=case["data_source_type"],
data_source_info=case["data_source_info"],
)
# Test data_source_info_dict (raw data)
raw_result = document.data_source_info_dict
assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}"
# Verify raw_result is always a valid dict
assert isinstance(raw_result, dict)
def test_local_file_data_source_info_without_db_context(self):
"""Test that local_file type data_source_info_dict works without database context."""
test_data: LocalFileInfo = {
"file_path": "/local/path/document.txt",
"size": 512,
"created_at": "2024-01-01T00:00:00Z",
}
document = Document(
data_source_type="local_file",
data_source_info=json.dumps(test_data),
)
# data_source_info_dict should return the raw data (this doesn't need DB context)
raw_data = document.data_source_info_dict
assert raw_data == test_data
assert isinstance(raw_data, dict)
# Verify the data contains expected keys for pipeline mode
assert "file_path" in raw_data
assert "size" in raw_data
def test_notion_and_website_crawl_data_source_detail(self):
"""Test that notion_import and website_crawl return raw data in data_source_detail_dict."""
# Test notion_import
notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"}
document = Document(
data_source_type="notion_import",
data_source_info=json.dumps(notion_data),
)
# data_source_detail_dict should return raw data for notion_import
detail_result = document.data_source_detail_dict
assert detail_result == notion_data
# Test website_crawl
website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"}
document = Document(
data_source_type="website_crawl",
data_source_info=json.dumps(website_data),
)
# data_source_detail_dict should return raw data for website_crawl
detail_result = document.data_source_detail_dict
assert detail_result == website_data
def test_local_file_data_source_detail_dict_without_db(self):
"""Test that local_file returns empty data_source_detail_dict (this doesn't need DB context)."""
# Test local_file - this should work without database context since it returns {} early
document = Document(
data_source_type="local_file",
data_source_info=json.dumps({"file_path": "/tmp/test.txt"}),
)
# Should return empty dict for local_file type (handled in the model)
detail_result = document.data_source_detail_dict
assert detail_result == {}

View File

@ -0,0 +1,144 @@
from collections.abc import Sequence
from datetime import datetime
from unittest.mock import Mock
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.variables import StringVariable
from core.variables.segments import Segment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.system_variable import SystemVariable
class MockReadOnlyVariablePool:
def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
self._variables = variables or {}
def get(self, selector: Sequence[str]) -> Segment | None:
if len(selector) < 2:
return None
return self._variables.get((selector[0], selector[1]))
def get_all_by_node(self, node_id: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
def get_by_prefix(self, prefix: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
def _build_graph_runtime_state(
variable_pool: MockReadOnlyVariablePool,
conversation_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
graph_runtime_state.variable_pool = variable_pool
graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
return graph_runtime_state
def _build_node_run_succeeded_event(
*,
node_type: NodeType,
outputs: dict[str, object] | None = None,
process_data: dict[str, object] | None = None,
) -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="node-exec-id",
node_id="assigner",
node_type=node_type,
start_at=datetime.utcnow(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs or {},
process_data=process_data or {},
),
)
def test_persists_conversation_variables_from_assigner_output():
conversation_id = "conv-123"
variable = StringVariable(
id="var-1",
name="name",
value="updated",
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
)
process_data = common_helpers.set_updated_variables(
{}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
)
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
layer.on_event(event)
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
updater.flush.assert_called_once()
def test_skips_when_outputs_missing():
conversation_id = "conv-456"
variable = StringVariable(
id="var-2",
name="name",
value="updated",
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
)
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
def test_skips_non_assigner_nodes():
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
def test_skips_non_conversation_variables():
conversation_id = "conv-789"
non_conversation_variable = StringVariable(
id="var-3",
name="name",
value="updated",
selector=["environment", "name"],
)
process_data = common_helpers.set_updated_variables(
{}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
)
variable_pool = MockReadOnlyVariablePool()
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
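
Taken together, these four tests describe the layer's decision flow: only variable-assigner nodes are considered, only selectors in the conversation scope are persisted, and `flush()` is called once when at least one variable was written. A self-contained sketch of that flow follows; the function name, the string node type, and the dict-based pool are hypothetical simplifications of the project classes.

```python
CONVERSATION_PREFIX = "conversation"

def maybe_persist(node_type: str, updated_selectors: list[list[str]], pool: dict,
                  conversation_id: str | None, updater) -> None:
    """Persist only variable-assigner updates whose selector targets the conversation scope."""
    if node_type != "variable-assigner" or not updated_selectors or conversation_id is None:
        return
    persisted = False
    for selector in updated_selectors:
        if selector[0] != CONVERSATION_PREFIX:
            continue  # environment/system selectors are skipped
        variable = pool.get(tuple(selector))
        if variable is None:
            continue
        updater.update(conversation_id=conversation_id, variable=variable)
        persisted = True
    if persisted:
        updater.flush()
```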

View File

@ -1,4 +1,5 @@
import json
from collections.abc import Sequence
from time import time
from unittest.mock import Mock
@ -15,6 +16,7 @@ from core.app.layers.pause_state_persist_layer import (
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunPausedEvent,
@ -66,8 +68,10 @@ class MockReadOnlyVariablePool:
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
self._variables = variables or {}
def get(self, node_id: str, variable_key: str) -> Segment | None:
value = self._variables.get((node_id, variable_key))
def get(self, selector: Sequence[str]) -> Segment | None:
if len(selector) < 2:
return None
value = self._variables.get((selector[0], selector[1]))
if value is None:
return None
mock_segment = Mock(spec=Segment)
@ -209,8 +213,9 @@ class TestPauseStatePersistenceLayer:
assert layer._session_maker is session_factory
assert layer._state_owner_user_id == state_owner_user_id
assert not hasattr(layer, "graph_runtime_state")
assert not hasattr(layer, "command_channel")
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
assert layer.command_channel is None
def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
@ -295,7 +300,7 @@ class TestPauseStatePersistenceLayer:
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
@ -305,7 +310,7 @@ class TestPauseStatePersistenceLayer:
event = TestDataFactory.create_graph_run_paused_event()
with pytest.raises(AttributeError):
with pytest.raises(GraphEngineLayerNotInitializedError):
layer.on_event(event)
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):

View File

@ -1,8 +1,12 @@
"""Primarily used for testing merged cell scenarios"""
import os
import tempfile
from types import SimpleNamespace
from docx import Document
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url():
dify_config.FILES_URL = original_files_url
if original_internal_files_url is not None:
dify_config.INTERNAL_FILES_URL = original_internal_files_url
def test_extract_hyperlinks(monkeypatch):
# Mock db and storage to avoid issues during image extraction (even if no images are present)
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
monkeypatch.setattr(we, "db", db_stub)
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
doc = Document()
p = doc.add_paragraph("Visit ")
# Adding a hyperlink manually
r_id = "rId99"
hyperlink = OxmlElement("w:hyperlink")
hyperlink.set(qn("r:id"), r_id)
new_run = OxmlElement("w:r")
t = OxmlElement("w:t")
t.text = "Dify"
new_run.append(t)
hyperlink.append(new_run)
p._p.append(hyperlink)
# Add relationship to the part
doc.part.rels.add_relationship(
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
"https://dify.ai",
r_id,
is_external=True,
)
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
doc.save(tmp.name)
tmp_path = tmp.name
try:
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
docs = extractor.extract()
# Verify modern hyperlink extraction
assert "Visit[Dify](https://dify.ai)" in docs[0].page_content
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
def test_extract_legacy_hyperlinks(monkeypatch):
# Mock db and storage
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
monkeypatch.setattr(we, "db", db_stub)
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
doc = Document()
p = doc.add_paragraph()
# Construct a legacy HYPERLINK field:
# 1. w:fldChar (begin)
# 2. w:instrText (HYPERLINK "http://example.com")
# 3. w:fldChar (separate)
# 4. w:r (visible text "Example")
# 5. w:fldChar (end)
run1 = OxmlElement("w:r")
fldCharBegin = OxmlElement("w:fldChar")
fldCharBegin.set(qn("w:fldCharType"), "begin")
run1.append(fldCharBegin)
p._p.append(run1)
run2 = OxmlElement("w:r")
instrText = OxmlElement("w:instrText")
instrText.text = ' HYPERLINK "http://example.com" '
run2.append(instrText)
p._p.append(run2)
run3 = OxmlElement("w:r")
fldCharSep = OxmlElement("w:fldChar")
fldCharSep.set(qn("w:fldCharType"), "separate")
run3.append(fldCharSep)
p._p.append(run3)
run4 = OxmlElement("w:r")
t4 = OxmlElement("w:t")
t4.text = "Example"
run4.append(t4)
p._p.append(run4)
run5 = OxmlElement("w:r")
fldCharEnd = OxmlElement("w:fldChar")
fldCharEnd.set(qn("w:fldCharType"), "end")
run5.append(fldCharEnd)
p._p.append(run5)
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
doc.save(tmp.name)
tmp_path = tmp.name
try:
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
docs = extractor.extract()
# Verify legacy hyperlink extraction
assert "[Example](http://example.com)" in docs[0].page_content
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)
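
Both tests reduce to the same output contract: a hyperlink, whether stored as a modern `w:hyperlink` relationship or a legacy `HYPERLINK` field, is rendered as a Markdown link. A hedged sketch of the field-to-Markdown step is below; `field_to_markdown` is an illustrative helper, not the extractor's real API, which walks the `w:fldChar` runs directly.

```python
import re

def field_to_markdown(instr_text: str, visible_text: str) -> str:
    """Turn a legacy HYPERLINK field instruction plus its visible text into a Markdown link."""
    match = re.search(r'HYPERLINK\s+"([^"]+)"', instr_text)
    return f"[{visible_text}]({match.group(1)})" if match else visible_text

assert field_to_markdown(' HYPERLINK "http://example.com" ', "Example") == "[Example](http://example.com)"
```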

View File

@ -0,0 +1,113 @@
import threading
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from flask import Flask, current_app
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from models.dataset import Dataset
class TestRetrievalService:
@pytest.fixture
def mock_dataset(self) -> Dataset:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid4())
dataset.tenant_id = str(uuid4())
dataset.name = "test_dataset"
dataset.indexing_technique = "high_quality"
dataset.provider = "dify"
return dataset
def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
"""
Repro test for current bug:
reranking runs after `with flask_app.app_context():` exits.
`_multiple_retrieve_thread` catches exceptions and stores them in `thread_exceptions`,
so assertions must inspect that list rather than rely on an outer try/except.
"""
dataset_retrieval = DatasetRetrieval()
flask_app = Flask(__name__)
tenant_id = str(uuid4())
# second dataset to ensure dataset_count > 1 reranking branch
secondary_dataset = Mock(spec=Dataset)
secondary_dataset.id = str(uuid4())
secondary_dataset.provider = "dify"
secondary_dataset.indexing_technique = "high_quality"
# retriever returns 1 doc into internal list (all_documents_item)
document = Document(
page_content="Context aware doc",
metadata={
"doc_id": "doc1",
"score": 0.95,
"document_id": str(uuid4()),
"dataset_id": mock_dataset.id,
},
provider="dify",
)
def fake_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.append(document)
called = {"init": 0, "invoke": 0}
class ContextRequiredPostProcessor:
def __init__(self, *args, **kwargs):
called["init"] += 1
# will raise RuntimeError if no Flask app context exists
_ = current_app.name
def invoke(self, *args, **kwargs):
called["invoke"] += 1
_ = current_app.name
return kwargs.get("documents") or args[1]
# output list from _multiple_retrieve_thread
all_documents: list[Document] = []
# IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
thread_exceptions: list[Exception] = []
def target():
with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
with patch(
"core.rag.retrieval.dataset_retrieval.DataPostProcessor",
ContextRequiredPostProcessor,
):
dataset_retrieval._multiple_retrieve_thread(
flask_app=flask_app,
available_datasets=[mock_dataset, secondary_dataset],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={
"reranking_provider_name": "cohere",
"reranking_model_name": "rerank-v2",
},
weights=None,
top_k=3,
score_threshold=0.0,
query="test query",
attachment_id=None,
dataset_count=2, # force reranking branch
thread_exceptions=thread_exceptions,  # exceptions raised inside the thread are collected here
)
t = threading.Thread(target=target)
t.start()
t.join()
# Ensure reranking branch was actually executed
assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
# Under the bug, a RuntimeError is recorded here and this assertion fails; the fixed code leaves the list empty
assert not thread_exceptions, thread_exceptions
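
The fix this test asserts is essentially a matter of scope: every step that touches `current_app`, including reranking, has to run while the pushed application context is still active. A runnable toy version of that shape is below; `run_thread` and the sorted-documents stand-in for post-processing are assumptions for illustration.

```python
from flask import Flask, current_app

def run_thread(flask_app: Flask, docs: list[str], all_documents: list[str], rerank: bool) -> None:
    # Keep every step that needs `current_app` inside the pushed context.
    with flask_app.app_context():
        if rerank:
            _ = current_app.name  # would raise RuntimeError outside the context
            docs = sorted(docs)   # stand-in for the reranking post-processor
        all_documents.extend(docs)

app = Flask(__name__)
out: list[str] = []
run_thread(app, ["b", "a"], out, rerank=True)
assert out == ["a", "b"]
```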

View File

@ -3,8 +3,15 @@
import json
from unittest.mock import MagicMock
from core.variables import IntegerVariable, StringVariable
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
GraphEngineCommand,
UpdateVariablesCommand,
VariableUpdate,
)
class TestRedisChannel:
@ -148,6 +155,43 @@ class TestRedisChannel:
assert commands[0].command_type == CommandType.ABORT
assert isinstance(commands[1], AbortCommand)
def test_fetch_commands_with_update_variables_command(self):
"""Test fetching update variables command from Redis."""
mock_redis = MagicMock()
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
),
]
)
command_json = json.dumps(update_command.model_dump())
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
assert len(commands) == 1
assert isinstance(commands[0], UpdateVariablesCommand)
assert isinstance(commands[0].updates[0].value, StringVariable)
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
assert commands[0].updates[0].value.value == "bar"
def test_fetch_commands_skips_invalid_json(self):
"""Test that invalid JSON commands are skipped."""
mock_redis = MagicMock()
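
The new test also documents the channel's wire format: commands are serialized with Pydantic's `model_dump()` to JSON, pushed onto a Redis list, and parsed back by command type on fetch, with malformed entries dropped. A tiny in-memory round trip illustrating that format is below; the plain list standing in for Redis and the `send`/`fetch` names are assumptions.

```python
import json

wire: list[bytes] = []

def send(payload: dict) -> None:
    wire.append(json.dumps(payload).encode())

def fetch() -> list[dict]:
    items = list(wire)
    wire.clear()
    out: list[dict] = []
    for raw in items:
        try:
            out.append(json.loads(raw))
        except json.JSONDecodeError:
            continue  # malformed entries are skipped, as the following test asserts
    return out

send({"command_type": "update_variables", "updates": [{"value": {"name": "foo", "value": "bar"}}]})
wire.append(b"not json")
fetched = fetch()
assert len(fetched) == 1 and fetched[0]["command_type"] == "update_variables"
```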

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import pytest
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers.base import (
GraphEngineLayer,
GraphEngineLayerNotInitializedError,
)
from core.workflow.graph_events import GraphEngineEvent
from ..test_table_runner import WorkflowRunner
class LayerForTest(GraphEngineLayer):
def on_graph_start(self) -> None:
pass
def on_event(self, event: GraphEngineEvent) -> None:
pass
def on_graph_end(self, error: Exception | None) -> None:
pass
def test_layer_runtime_state_raises_when_uninitialized() -> None:
layer = LayerForTest()
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
def test_layer_runtime_state_available_after_engine_layer() -> None:
runner = WorkflowRunner()
fixture_data = runner.load_fixture("simple_passthrough_workflow")
graph, graph_runtime_state = runner.create_graph_from_fixture(
fixture_data,
inputs={"query": "test layer state"},
)
engine = GraphEngine(
workflow_id="test_workflow",
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
)
layer = LayerForTest()
engine.layer(layer)
outputs = layer.graph_runtime_state.outputs
ready_queue_size = layer.graph_runtime_state.ready_queue_size
assert outputs == {}
assert isinstance(ready_queue_size, int)
assert ready_queue_size >= 0
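
The pattern these tests pin down is a guarded lazy property: the state is `None` until `initialize()` is called, and access before that raises a dedicated error rather than `AttributeError`. A minimal sketch of that pattern, with hypothetical names, is below.

```python
class NotInitializedError(RuntimeError):
    """Raised when layer state is read before initialize()."""

class LayerSketch:
    def __init__(self) -> None:
        self._state = None

    def initialize(self, state) -> None:
        self._state = state

    @property
    def graph_runtime_state(self):
        if self._state is None:
            raise NotInitializedError("call initialize() before accessing graph_runtime_state")
        return self._state

layer = LayerSketch()
try:
    _ = layer.graph_runtime_state
except NotInitializedError:
    pass  # expected before initialize()
layer.initialize({"outputs": {}})
assert layer.graph_runtime_state == {"outputs": {}}
```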

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import queue
import threading
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=execution_coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
assert event_queue.empty()
@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int:
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
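
The added `stop_event` constructor argument implies a loop of roughly the following shape: the dispatcher keeps draining queued events and only exits once the shared event is set and the queue is empty, which is why the tests above still see an empty queue after a pause. The sketch below is an assumption about that shape, not the dispatcher's actual code.

```python
import queue
import threading

def dispatcher_loop(event_queue: queue.Queue, handle, stop_event: threading.Event) -> None:
    """Drain events; exit only when the shared stop_event is set and the queue is empty."""
    while True:
        try:
            event = event_queue.get(timeout=0.1)
        except queue.Empty:
            if stop_event.is_set():
                return
            continue
        handle(event)

# Remaining events are still handled even when the stop was requested beforehand.
q: queue.Queue = queue.Queue()
q.put("node_succeeded")
stop = threading.Event()
stop.set()
seen: list[str] = []
dispatcher_loop(q, seen.append, stop)
assert seen == ["node_succeeded"]
```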

View File

@ -4,12 +4,19 @@ import time
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import IntegerVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -180,3 +187,67 @@ def test_pause_command():
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
def test_update_variables_command_updates_pool():
"""Test that GraphEngine updates variable pool via update variables command."""
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=shared_runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
command_channel = InMemoryChannel()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=shared_runtime_state,
command_channel=command_channel,
)
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
),
]
)
command_channel.send_command(update_command)
list(engine.run())
updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])
assert updated_existing is not None
assert updated_existing.value == "new value"
assert added_new is not None
assert added_new.value == 123
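
The command's effect on the pool is plain overwrite-or-create keyed by selector, which is what the two final assertions check. A self-contained sketch of that apply step is below; `PoolSketch` and `apply_updates` are illustrative stand-ins for `VariablePool` and the engine's command handler.

```python
class PoolSketch:
    """Selector-keyed store mimicking VariablePool.add()/get() with list selectors."""

    def __init__(self) -> None:
        self._store: dict[tuple[str, ...], object] = {}

    def add(self, selector, value) -> None:
        self._store[tuple(selector)] = value

    def get(self, selector):
        return self._store.get(tuple(selector))

def apply_updates(pool: PoolSketch, updates: list[tuple[list[str], object]]) -> None:
    for selector, value in updates:
        pool.add(selector, value)  # overwrite existing entries or create new ones

pool = PoolSketch()
pool.add(["node1", "foo"], "old value")
apply_updates(pool, [(["node1", "foo"], "new value"), (["node2", "bar"], 123)])
assert pool.get(["node1", "foo"]) == "new value"
assert pool.get(["node2", "bar"]) == 123
```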

View File

@ -0,0 +1,539 @@
"""
Unit tests for stop_event functionality in GraphEngine.
Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""
import threading
import time
from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
class TestStopEventPropagation:
"""Test suite for stop_event propagation through GraphEngine components."""
def test_graph_engine_creates_stop_event(self):
"""Test that GraphEngine creates a stop_event on initialization."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Verify stop_event was created
assert engine._stop_event is not None
assert isinstance(engine._stop_event, threading.Event)
# Verify it was set in graph_runtime_state
assert runtime_state.stop_event is not None
assert runtime_state.stop_event is engine._stop_event
def test_stop_event_cleared_on_start(self):
"""Test that stop_event is cleared when execution starts."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Set the stop_event before running
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear the stop_event)
events = list(engine.run())
# After running, stop_event should be set again (by _stop_execution)
# But during start it was cleared
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
def test_stop_event_set_on_stop(self):
"""Test that stop_event is set when execution stops."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Initially not set
assert not engine._stop_event.is_set()
# Run the engine
list(engine.run())
# After execution completes, stop_event should be set
assert engine._stop_event.is_set()
def test_stop_event_passed_to_worker_pool(self):
"""Test that stop_event is passed to WorkerPool."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Verify WorkerPool has the stop_event
assert engine._worker_pool._stop_event is not None
assert engine._worker_pool._stop_event is engine._stop_event
def test_stop_event_passed_to_dispatcher(self):
"""Test that stop_event is passed to Dispatcher."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Verify Dispatcher has the stop_event
assert engine._dispatcher._stop_event is not None
assert engine._dispatcher._stop_event is engine._stop_event
class TestNodeStopCheck:
"""Test suite for Node._should_stop() functionality."""
def test_node_should_stop_checks_runtime_state(self):
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Initially stop_event is not set
assert not answer_node._should_stop()
# Set the stop_event
runtime_state.stop_event.set()
# Now _should_stop should return True
assert answer_node._should_stop()
def test_node_run_checks_stop_event_between_yields(self):
"""Test that Node.run() checks stop_event between yielding events."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Set stop_event BEFORE running the node
runtime_state.stop_event.set()
# Run the node - should yield start event then detect stop
# The node should check stop_event before processing
assert answer_node._should_stop(), "stop_event should be set"
# Run and collect events
events = list(answer_node.run())
# Since stop_event is set at the start, we should get:
# 1. NodeRunStartedEvent (always yielded first)
# 2. Either NodeRunFailedEvent (if the stop is detected in time) or NodeRunSucceededEvent (if the node finishes before the check)
assert len(events) >= 2
assert isinstance(events[0], NodeRunStartedEvent)
# Note: AnswerNode is very simple and might complete before stop check
# The important thing is that _should_stop() returns True when stop_event is set
assert answer_node._should_stop()
class TestStopEventIntegration:
"""Integration tests for stop_event in workflow execution."""
def test_simple_workflow_respects_stop_event(self):
"""Test that a simple workflow respects stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.nodes["answer"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Set stop_event before running
runtime_state.stop_event.set()
# Run the engine
events = list(engine.run())
# Should get started event but not succeeded (due to stop)
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
# The workflow should still complete (start node runs quickly)
# but answer node might be cancelled depending on timing
def test_stop_event_with_concurrent_nodes(self):
"""Test stop_event behavior with multiple concurrent nodes."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
# Create multiple nodes
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes[f"answer_{i}"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# All nodes should share the same stop_event
for node in mock_graph.nodes.values():
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
assert node.graph_runtime_state.stop_event is engine._stop_event
class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
dispatcher = engine._dispatcher
dispatcher.start() # This will create and start the mocked thread
mock_thread_instance = mock_thread_cls.return_value
mock_thread_instance.is_alive.return_value = True
dispatcher.stop()
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
worker_pool = engine._worker_pool
worker_pool.start(initial_count=1) # Start with one worker
mock_worker_instance = mock_worker_cls.return_value
mock_worker_instance.is_alive.return_value = True
worker_pool.stop()
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
class TestStopEventResumeBehavior:
"""Test stop_event behavior during workflow resume."""
def test_stop_event_cleared_on_resume(self):
"""Test that stop_event is cleared when resuming a paused workflow."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Simulate a previous execution that set stop_event
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear stop_event in _start_execution)
events = list(engine.run())
# Execution should complete successfully
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
class TestWorkerStopBehavior:
"""Test Worker behavior with shared stop_event."""
def test_worker_uses_shared_stop_event(self):
"""Test that Worker uses shared stop_event from GraphEngine."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)
# Get the worker pool and check workers
worker_pool = engine._worker_pool
# Start the worker pool to create workers
worker_pool.start()
# Check that at least one worker was created
assert len(worker_pool._workers) > 0
# Verify workers use the shared stop_event
for worker in worker_pool._workers:
assert worker._stop_event is engine._stop_event
# Clean up
worker_pool.stop()
def test_worker_stop_is_noop(self):
"""Test that Worker.stop() is now a no-op."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a mock worker
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_engine.worker import Worker
ready_queue = InMemoryReadyQueue()
event_queue = MagicMock()
# Create a proper mock graph with real dict
mock_graph = Mock(spec=Graph)
mock_graph.nodes = {} # Use real dict
stop_event = threading.Event()
worker = Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=mock_graph,
layers=[],
stop_event=stop_event,
)
# Calling stop() should do nothing (no-op)
# and should NOT set the stop_event
worker.stop()
assert not stop_event.is_set()
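
Across all of these suites the contract is cooperative cancellation: engine, workers, dispatcher, and nodes share one `threading.Event`, and long-running work checks it between steps rather than being killed. A minimal illustration of that pattern (not the node base class itself) is below.

```python
import threading

def run_steps(steps, stop_event: threading.Event):
    """Yield step results, stopping early once the shared event is set."""
    for step in steps:
        if stop_event.is_set():
            yield "stopped"
            return
        yield step()

stop = threading.Event()
results = []
for i, out in enumerate(run_steps([lambda: "a", lambda: "b", lambda: "c"], stop)):
    results.append(out)
    if i == 0:
        stop.set()  # request cancellation after the first step
assert results == ["a", "stopped"]
```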

View File

@ -78,7 +78,7 @@ class TestFileSaverImpl:
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"

View File

@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType
@ -127,7 +127,9 @@ class TestTemplateTransformNode:
"""Test version class method."""
assert TemplateTransformNode.version() == "1"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
mock_execute.return_value = "Hello Alice, you are 30 years old!"
node = TemplateTransformNode(
id="test_node",
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
}
mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
mock_execute.return_value = "Value: "
node = TemplateTransformNode(
id="test_node",
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
mock_execute.side_effect = TemplateRenderError("Template syntax error")
node = TemplateTransformNode(
id="test_node",
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
mock_execute.return_value = "This is a very long output that exceeds the limit"
node = TemplateTransformNode(
id="test_node",
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
mock_execute.return_value = "apple, banana, orange (Total: 3)"
node = TemplateTransformNode(
id="test_node",
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
"template": "This is a static message.",
}
mock_execute.return_value = {"result": "This is a static message."}
mock_execute.return_value = "This is a static message."
node = TemplateTransformNode(
id="test_node",
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
mock_execute.return_value = "Total: $31.5"
node = TemplateTransformNode(
id="test_node",
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
node = TemplateTransformNode(
id="test_node",
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
mock_execute.return_value = "Tags: #python #ai #workflow "
node = TemplateTransformNode(
id="test_node",

View File

@ -1,14 +1,14 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -86,9 +86,6 @@ def test_overwrite_string_variable():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -104,20 +101,14 @@ def test_overwrite_string_variable():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == input_variable.value
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -191,9 +182,6 @@ def test_append_variable_to_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -209,22 +197,14 @@ def test_append_variable_to_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == ["the first value", "the second value"]
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -287,9 +267,6 @@ def test_clear_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -305,20 +282,14 @@ def test_clear_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == []
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
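
The rewritten assertions encode a new contract: the assigner no longer persists conversation variables itself through an injected updater; it records what changed in the node's `process_data`, and the persistence layer tested earlier in this diff performs the actual write. A toy version of that record/read-back split is below; the `"updated_variables"` key and the helper names are assumptions, not the real schema used by `common_helpers`.

```python
def set_updated_variables_sketch(process_data: dict, updated: list[dict]) -> dict:
    process_data["updated_variables"] = updated  # hypothetical key name
    return process_data

def get_updated_variables_sketch(process_data: dict) -> list[dict] | None:
    return process_data.get("updated_variables")

pd = set_updated_variables_sketch({}, [{"name": "name", "new_value": "updated"}])
assert get_updated_variables_sketch(pd) == [{"name": "name", "new_value": "updated"}]
```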

View File

@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
def test_node_factory_creates_variable_assigner_node():
graph_config = {
"edges": [],
"nodes": [
{
"data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
"id": "assigner",
},
],
}
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(graph_config["nodes"][0])
assert isinstance(node, VariableAssignerNode)

View File

@ -1,6 +1,6 @@
import pytest
from libs.helper import extract_tenant_id
from libs.helper import escape_like_pattern, extract_tenant_id
from models.account import Account
from models.model import EndUser
@ -63,3 +63,51 @@ class TestExtractTenantId:
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)
class TestEscapeLikePattern:
"""Test cases for the escape_like_pattern utility function."""
def test_escape_percent_character(self):
"""Test escaping percent character."""
result = escape_like_pattern("50% discount")
assert result == "50\\% discount"
def test_escape_underscore_character(self):
"""Test escaping underscore character."""
result = escape_like_pattern("test_data")
assert result == "test\\_data"
def test_escape_backslash_character(self):
"""Test escaping backslash character."""
result = escape_like_pattern("path\\to\\file")
assert result == "path\\\\to\\\\file"
def test_escape_combined_special_characters(self):
"""Test escaping multiple special characters together."""
result = escape_like_pattern("file_50%\\path")
assert result == "file\\_50\\%\\\\path"
def test_escape_empty_string(self):
"""Test escaping empty string returns empty string."""
result = escape_like_pattern("")
assert result == ""
def test_escape_none_handling(self):
"""Test escaping None returns None (falsy check handles it)."""
# The function checks `if not pattern`, so None is falsy and is returned as-is
result = escape_like_pattern(None)
assert result is None
def test_escape_normal_string_no_change(self):
"""Test that normal strings without special characters are unchanged."""
result = escape_like_pattern("normal text")
assert result == "normal text"
def test_escape_order_matters(self):
"""Test that backslash is escaped first to prevent double escaping."""
# If we escape % first, then escape \, we might get wrong results
# This test ensures the order is correct: \ first, then % and _
result = escape_like_pattern("test\\%_value")
# Should be: test\\\%\_value
assert result == "test\\\\\\%\\_value"
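
For reference, a behaviourally equivalent sketch of the escaping these cases describe is below: escape the backslash first, then the SQL `LIKE` wildcards, so already-escaped characters are not double-processed. This is a minimal reimplementation for illustration, assuming `escape_like_pattern` follows exactly this ordering.

```python
def escape_like_pattern_sketch(pattern: str | None) -> str | None:
    """Escape SQL LIKE wildcards; backslash must be handled before % and _."""
    if not pattern:
        return pattern
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

assert escape_like_pattern_sketch("50% discount") == "50\\% discount"
assert escape_like_pattern_sketch("test\\%_value") == "test\\\\\\%\\_value"
assert escape_like_pattern_sketch("") == ""
assert escape_like_pattern_sketch(None) is None
```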

Some files were not shown because too many files have changed in this diff