Mirror of https://github.com/langgenius/dify.git (synced 2026-01-14 06:07:33 +08:00)

Merge branch 'main' into feat/grouping-branching

# Conflicts:
#	web/package.json

Commit 760a739e91

2	.github/pull_request_template.md (vendored)
@@ -20,4 +20,4 @@
 - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
 - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
 - [x] I've updated the documentation accordingly.
-- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
5	Makefile

@@ -60,9 +60,10 @@ check:
 	@echo "✅ Code check complete"

 lint:
-	@echo "🔧 Running ruff format, check with fixes, and import linter..."
+	@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
 	@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
 	@uv run --directory api --dev lint-imports
+	@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
 	@echo "✅ Linting complete"

 type-check:

@@ -122,7 +123,7 @@ help:
 	@echo "Backend Code Quality:"
 	@echo " make format - Format code with ruff"
 	@echo " make check - Check code with ruff"
-	@echo " make lint - Format and fix code with ruff"
+	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
 	@echo " make type-check - Run type checking with basedpyright"
 	@echo " make test - Run backend unit tests"
 	@echo ""
@@ -575,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false
 # Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
 # Useful for migration scenarios where historical data exists only in SQL database
 LOGSTORE_DUAL_READ_ENABLED=true
+# Control flag for whether to write the `graph` field to LogStore.
+# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
+# otherwise write an empty {} instead. Defaults to writing the `graph` field.
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true

 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1
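The comments above only describe the flag's semantics; as a quick illustration, a consumer of `LOGSTORE_ENABLE_PUT_GRAPH_FIELD` might gate the `graph` payload as in the sketch below. The function name and call site are assumptions for illustration, not Dify's actual code.

```python
import json
import os

def graph_field_for_logstore(graph: dict) -> str:
    # Hypothetical helper: when LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true"
    # (the default), persist the full `graph`; otherwise store an empty {}.
    enabled = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
    return json.dumps(graph if enabled else {})
```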
@@ -3,9 +3,11 @@ root_packages =
    core
    configs
    controllers
    extensions
    models
    tasks
    services
include_external_packages = True

[importlinter:contract:workflow]
name = Workflow

@@ -33,6 +35,28 @@ ignore_imports =
    core.workflow.nodes.loop.loop_node -> core.workflow.graph
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels

[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
    core.workflow
forbidden_modules =
    extensions.ext_database
    extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
    core.workflow.nodes.agent.agent_node -> extensions.ext_database
    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
    core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
    core.workflow.nodes.llm.file_saver -> extensions.ext_database
    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
    core.workflow.nodes.llm.node -> extensions.ext_database
    core.workflow.nodes.tool.tool_node -> extensions.ext_database
    core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

[importlinter:contract:rsc]
name = RSC
type = layers
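The new forbidden contract keeps `core.workflow` from importing the infrastructure extensions directly, with the listed legacy nodes grandfathered in via `ignore_imports`. A rough sketch of the pattern such a contract encourages, injecting infrastructure at the boundary instead of importing `extensions.ext_database` inside workflow code; the class and parameter names here are illustrative, not Dify's API:

```python
from collections.abc import Callable
from sqlalchemy.orm import Session

# from extensions.ext_database import db   # <- would violate the contract above

class ExampleWorkflowNode:
    """Illustrative node that receives its DB access from the caller."""

    def __init__(self, session_factory: Callable[[], Session]) -> None:
        self._session_factory = session_factory  # injected, not imported from extensions

    def run(self) -> None:
        with self._session_factory() as session:
            # ...query via the injected session instead of a module-level engine
            pass
```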
@@ -50,16 +50,33 @@ WORKDIR /app/api

# Create non-root user
ARG dify_uid=1001
ARG NODE_MAJOR=22
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
    useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
    chown -R dify:dify /app

RUN \
    apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
    && gpg --show-keys --with-colons /tmp/nodesource.gpg \
        | awk -F: '/^fpr:/ {print $10}' \
        | grep -Fx "${NODESOURCE_KEY_FPR}" \
    && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
    && rm -f /tmp/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
        > /etc/apt/sources.list.d/nodesource.list \
    && apt-get update \
    # Install dependencies
    && apt-get install -y --no-install-recommends \
        # basic environment
-        curl nodejs \
+        nodejs=${NODE_PACKAGE_VERSION} \
        # for gmpy2 \
        libgmp-dev libmpfr-dev libmpc-dev \
        # For Security

@@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

 # Download nltk data
-RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
+RUN mkdir -p /usr/local/share/nltk_data \
+    && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
     && chmod -R 755 /usr/local/share/nltk_data

 ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
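The pipeline above pins the NodeSource signing key by fingerprint before trusting the repository. To reproduce the same check outside the Docker build, a small Python sketch mirroring the Dockerfile's own `gpg --show-keys --with-colons` / `fpr:` parsing (the helper name is made up):

```python
import subprocess

EXPECTED_FPR = "6F71F525282841EEDAF851B42F59B5F99B1BE0B4"

def key_fingerprints(key_path: str) -> set[str]:
    # `fpr:` records carry the fingerprint in the 10th colon-separated field,
    # which is exactly what the Dockerfile's awk expression extracts.
    out = subprocess.run(
        ["gpg", "--show-keys", "--with-colons", key_path],
        check=True, capture_output=True, text=True,
    ).stdout
    return {line.split(":")[9] for line in out.splitlines() if line.startswith("fpr:")}

# key_fingerprints("/tmp/nodesource.gpg") should contain EXPECTED_FPR
```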
213	api/commands.py

@@ -235,7 +235,7 @@ def migrate_annotation_vector_database():
         if annotations:
             for annotation in annotations:
                 document = Document(
-                    page_content=annotation.question,
+                    page_content=annotation.question_text,
                     metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
                 )
                 documents.append(document)
@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
|
||||
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
||||
|
||||
|
||||
@click.command("file-usage", help="Query file usages and show where files are referenced.")
|
||||
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
|
||||
@click.option("--key", type=str, default=None, help="Filter by storage key.")
|
||||
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
|
||||
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
|
||||
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
|
||||
def file_usage(
|
||||
file_id: str | None,
|
||||
key: str | None,
|
||||
src: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
output_json: bool,
|
||||
):
|
||||
"""
|
||||
Query file usages and show where files are referenced in the database.
|
||||
|
||||
This command reuses the same reference checking logic as clear-orphaned-file-records
|
||||
and displays detailed information about where each file is referenced.
|
||||
"""
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "id_column": "id", "key_column": "key"},
|
||||
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
|
||||
]
|
||||
ids_tables = [
|
||||
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
|
||||
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
|
||||
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
|
||||
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
|
||||
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
|
||||
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
|
||||
]
|
||||
|
||||
# Stream file usages with pagination to avoid holding all results in memory
|
||||
paginated_usages = []
|
||||
total_count = 0
|
||||
|
||||
# First, build a mapping of file_id -> storage_key from the base tables
|
||||
file_key_map = {}
|
||||
for files_table in files_tables:
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
|
||||
|
||||
# If filtering by key or file_id, verify it exists
|
||||
if file_id and file_id not in file_key_map:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
if key:
|
||||
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
|
||||
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
|
||||
if not matching_file_ids:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# For each reference table/column, find matching file IDs and record the references
|
||||
for ids_table in ids_tables:
|
||||
src_filter = f"{ids_table['table']}.{ids_table['column']}"
|
||||
|
||||
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
|
||||
if src:
|
||||
if "%" in src or "_" in src:
|
||||
import fnmatch
|
||||
|
||||
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
|
||||
pattern = src.replace("%", "*").replace("_", "?")
|
||||
if not fnmatch.fnmatch(src_filter, pattern):
|
||||
continue
|
||||
else:
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
"total": total_count,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"usages": paginated_usages,
|
||||
}
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
click.echo(
|
||||
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
|
||||
)
|
||||
click.echo("")
|
||||
|
||||
if not paginated_usages:
|
||||
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
|
||||
return
|
||||
|
||||
# Print table header
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
click.echo(click.style("-" * 190, fg="white"))
|
||||
|
||||
# Print each usage
|
||||
for usage in paginated_usages:
|
||||
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
|
||||
|
||||
# Show pagination info
|
||||
if offset + limit < total_count:
|
||||
click.echo("")
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
|
||||
)
|
||||
)
|
||||
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
|
||||
|
||||
|
||||
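Assuming `file-usage` is registered with the Flask CLI like the neighbouring commands, an invocation might look like `flask file-usage --src 'documents.%' --limit 50 --json`, or `flask file-usage --file-id <uuid>` to trace a single upload; the options map directly to the `@click.option` declarations above.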
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
|
||||
@@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
        description="Authentication token for Milvus, if token-based authentication is enabled",
        default=None,
    )

    MILVUS_USER: str | None = Field(
        description="Username for authenticating with Milvus, if username/password authentication is enabled",
        default=None,
@ -1,14 +1,16 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -19,27 +21,19 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
deleted_tool_fields,
|
||||
model_config_fields,
|
||||
model_config_partial_fields,
|
||||
site_fields,
|
||||
tag_fields,
|
||||
)
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
@ -192,124 +186,292 @@ class AppTracePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
tag_model = console_ns.model("Tag", tag_fields)
|
||||
|
||||
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
model_config_model = console_ns.model("ModelConfig", model_config_fields)
|
||||
|
||||
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE.value:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
|
||||
|
||||
site_model = console_ns.model("Site", site_fields)
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
|
||||
app_partial_model = console_ns.model(
|
||||
"AppPartial",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"max_active_requests": fields.Raw(),
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_model = console_ns.model(
|
||||
"AppDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
},
|
||||
)
|
||||
class WorkflowPartial(ResponseModel):
|
||||
id: str
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
app_detail_with_site_model = console_ns.model(
|
||||
"AppDetailWithSite",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"site": fields.Nested(site_model),
|
||||
},
|
||||
)
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
app_pagination_model = console_ns.model(
|
||||
"AppPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
|
||||
},
|
||||
|
||||
class ModelConfigPartial(ResponseModel):
|
||||
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
|
||||
pre_prompt: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
|
||||
)
|
||||
suggested_questions_after_answer: JSONValue | None = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
|
||||
)
|
||||
speech_to_text: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
|
||||
)
|
||||
text_to_speech: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
|
||||
)
|
||||
retriever_resource: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
|
||||
)
|
||||
annotation_reply: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
|
||||
)
|
||||
more_like_this: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
|
||||
)
|
||||
sensitive_word_avoidance: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
|
||||
)
|
||||
external_data_tools: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
|
||||
)
|
||||
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
|
||||
user_input_form: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
|
||||
)
|
||||
dataset_query_variable: str | None = None
|
||||
pre_prompt: str | None = None
|
||||
agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
|
||||
prompt_type: str | None = None
|
||||
chat_prompt_config: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
|
||||
)
|
||||
completion_prompt_config: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
|
||||
)
|
||||
dataset_configs: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
|
||||
)
|
||||
file_upload: JSONValue | None = Field(
|
||||
default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
|
||||
)
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class Site(ResponseModel):
|
||||
access_token: str | None = Field(default=None, validation_alias="code")
|
||||
code: str | None = None
|
||||
title: str | None = None
|
||||
icon_type: str | IconType | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
default_language: str | None = None
|
||||
chat_color_theme: str | None = None
|
||||
chat_color_theme_inverted: bool | None = None
|
||||
customize_domain: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
customize_token_strategy: str | None = None
|
||||
prompt_public: bool | None = None
|
||||
app_base_url: str | None = None
|
||||
show_workflow_steps: bool | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("icon_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
|
||||
if isinstance(value, IconType):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class DeletedTool(ResponseModel):
|
||||
type: str
|
||||
tool_name: str
|
||||
provider_id: str
|
||||
|
||||
|
||||
class AppPartial(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
max_active_requests: int | None = None
|
||||
description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
|
||||
mode: str = Field(validation_alias="mode_compatible_with_agent")
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
model_config_: ModelConfigPartial | None = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("app_model_config", "model_config"),
|
||||
alias="model_config",
|
||||
)
|
||||
workflow: WorkflowPartial | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
tags: list[Tag] = Field(default_factory=list)
|
||||
access_mode: str | None = None
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetail(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str = Field(validation_alias="mode_compatible_with_agent")
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
enable_site: bool
|
||||
enable_api: bool
|
||||
model_config_: ModelConfig | None = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("app_model_config", "model_config"),
|
||||
alias="model_config",
|
||||
)
|
||||
workflow: WorkflowPartial | None = None
|
||||
tracing: JSONValue | None = None
|
||||
use_icon_as_answer_icon: bool | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
access_mode: str | None = None
|
||||
tags: list[Tag] = Field(default_factory=list)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetailWithSite(AppDetail):
|
||||
icon_type: str | None = None
|
||||
api_base_url: str | None = None
|
||||
max_active_requests: int | None = None
|
||||
deleted_tools: list[DeletedTool] = Field(default_factory=list)
|
||||
site: Site | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
return _build_icon_url(self.icon_type, self.icon)
|
||||
|
||||
|
||||
class AppPagination(ResponseModel):
|
||||
page: int
|
||||
limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
|
||||
total: int
|
||||
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
|
||||
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
|
||||
|
||||
|
||||
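For readers following the migration from `marshal` to Pydantic response models: the aliases above let the paginator returned by the service layer be validated directly. A minimal sketch, where `app_pagination` stands in for the object returned by `AppService.get_paginate_apps()` as used further down in this file:

```python
# Validate the SQLAlchemy pagination object against the response model;
# `per_page` surfaces as "limit", `has_next` as "has_more", and `items` as "data",
# matching the shape the old flask_restx model produced.
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
payload = pagination_model.model_dump(mode="json")
```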
class AppExportResponse(ResponseModel):
|
||||
data: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AppListQuery,
|
||||
CreateAppPayload,
|
||||
UpdateAppPayload,
|
||||
CopyAppPayload,
|
||||
AppExportQuery,
|
||||
AppNamePayload,
|
||||
AppIconPayload,
|
||||
AppSiteStatusPayload,
|
||||
AppApiStatusPayload,
|
||||
AppTracePayload,
|
||||
Tag,
|
||||
WorkflowPartial,
|
||||
ModelConfigPartial,
|
||||
ModelConfig,
|
||||
Site,
|
||||
DeletedTool,
|
||||
AppPartial,
|
||||
AppDetail,
|
||||
AppDetailWithSite,
|
||||
AppPagination,
|
||||
AppExportResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -318,7 +480,7 @@ class AppListApi(Resource):
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -334,7 +496,8 @@ class AppListApi(Resource):
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_ids = [str(app.id) for app in app_pagination.items]
|
||||
@ -378,18 +541,18 @@ class AppListApi(Resource):
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_model), 200
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@ -399,8 +562,8 @@ class AppListApi(Resource):
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
|
||||
return app, 201
|
||||
app_detail = AppDetail.model_validate(app, from_attributes=True)
|
||||
return app_detail.model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>")
|
||||
@ -408,13 +571,12 @@ class AppApi(Resource):
|
||||
@console_ns.doc("get_app_detail")
|
||||
@console_ns.doc(description="Get application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Success", app_detail_with_site_model)
|
||||
@console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
@ -425,21 +587,21 @@ class AppApi(Resource):
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
return app_model
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
@ -456,8 +618,8 @@ class AppApi(Resource):
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
return app_model
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_app")
|
||||
@console_ns.doc(description="Delete application")
|
||||
@ -483,14 +645,13 @@ class AppCopyApi(Resource):
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@ -516,7 +677,8 @@ class AppCopyApi(Resource):
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
app = session.scalar(stmt)
|
||||
|
||||
return app, 201
|
||||
response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
|
||||
return response_model.model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export")
|
||||
@ -525,11 +687,7 @@ class AppExportApi(Resource):
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
)
|
||||
@console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@ -540,13 +698,14 @@ class AppExportApi(Resource):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
payload = AppExportResponse(
|
||||
data=AppDslService.export_dsl(
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
}
|
||||
)
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
@ -555,20 +714,19 @@ class AppNameApi(Resource):
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
|
||||
return app_model
|
||||
response_model = AppDetail.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/icon")
|
||||
@ -582,16 +740,15 @@ class AppIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
|
||||
return app_model
|
||||
response_model = AppDetail.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site-enable")
|
||||
@ -600,21 +757,20 @@ class AppSiteStatus(Resource):
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
|
||||
return app_model
|
||||
response_model = AppDetail.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/api-enable")
|
||||
@ -623,21 +779,20 @@ class AppApiStatus(Resource):
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@get_app_model(mode=None)
|
||||
def post(self, app_model):
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
|
||||
return app_model
|
||||
response_model = AppDetail.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace")
|
||||
|
||||
@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
|
||||
)
|
||||
|
||||
if args.keyword:
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(args.keyword)
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_keyword = escape_like_pattern(args.keyword)
|
||||
keyword_filter = f"%{escaped_keyword}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
@ -469,11 +475,11 @@ class ChatConversationApi(Resource):
|
||||
.join(subquery, subquery.c.conversation_id == Conversation.id)
|
||||
.where(
|
||||
or_(
|
||||
Message.query.ilike(keyword_filter),
|
||||
Message.answer.ilike(keyword_filter),
|
||||
Conversation.name.ilike(keyword_filter),
|
||||
Conversation.introduction.ilike(keyword_filter),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter),
|
||||
Message.query.ilike(keyword_filter, escape="\\"),
|
||||
Message.answer.ilike(keyword_filter, escape="\\"),
|
||||
Conversation.name.ilike(keyword_filter, escape="\\"),
|
||||
Conversation.introduction.ilike(keyword_filter, escape="\\"),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
|
||||
),
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
|
||||
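Both hunks above switch the keyword search to `escape_like_pattern` plus an explicit `escape="\\"`. A minimal sketch of what such a helper is assumed to do; the real `libs.helper.escape_like_pattern` may differ:

```python
def escape_like_pattern(keyword: str) -> str:
    # Escape the escape character first, then the SQL LIKE wildcards, so that
    # input such as "100%_done" is matched literally rather than as a pattern.
    return keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

# Used as: Message.query.ilike(f"%{escape_like_pattern(kw)}%", escape="\\")
```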
@@ -1,3 +1,5 @@
+from typing import Any
+
 import flask_login
 from flask import make_response, request
 from flask_restx import Resource

@@ -96,14 +98,13 @@ class LoginApi(Resource):
         if is_login_error_rate_limit:
             raise EmailPasswordLoginLimitError()

-        # TODO: why invitation is re-assigned with different type?
-        invitation = args.invite_token  # type: ignore
-        if invitation:
-            invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
+        invitation_data: dict[str, Any] | None = None
+        if args.invite_token:
+            invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)

         try:
-            if invitation:
-                data = invitation.get("data", {})  # type: ignore
+            if invitation_data:
+                data = invitation_data.get("data", {})
             invitee_email = data.get("email") if data else None
             if invitee_email != args.email:
                 raise InvalidEmailError()
@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
|
||||
@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
# Search in both content and keywords fields
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
.scalar_subquery()
|
||||
),
|
||||
",",
|
||||
).ilike(f"%{keyword}%")
|
||||
).ilike(f"%{escaped_keyword}%", escape="\\")
|
||||
else:
|
||||
# MySQL: Cast JSON to string for pattern matching
|
||||
# MySQL stores Chinese text directly in JSON without Unicode escaping
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
|
||||
|
||||
query = query.where(
|
||||
or_(
|
||||
DocumentSegment.content.ilike(f"%{keyword}%"),
|
||||
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
|
||||
keywords_condition,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
 import logging
 from typing import Any

-from flask_restx import marshal, reqparse
+from flask_restx import marshal
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

@@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
         HitTestingService.hit_testing_args_check(args)

     @staticmethod
-    def parse_args():
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("query", type=str, required=False, location="json")
-            .add_argument("attachment_ids", type=list, required=False, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, location="json")
-            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        )
-        return parser.parse_args()
+    def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
+        """Validate and return hit-testing arguments from an incoming payload."""
+        hit_testing_payload = HitTestingPayload.model_validate(payload or {})
+        return hit_testing_payload.model_dump(exclude_none=True)

     @staticmethod
     def perform_hit_testing(dataset, args):
@ -4,12 +4,11 @@ from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_at_least_one_field(self):
|
||||
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
|
||||
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
|
||||
return self
|
||||
|
||||
|
||||
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||
"""Request payload for verifying subscription credentials."""
|
||||
@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=user.current_tenant_id,
|
||||
@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
|
||||
try:
|
||||
# rename only
|
||||
if (
|
||||
args.name is not None
|
||||
and args.credentials is None
|
||||
and args.parameters is None
|
||||
and args.properties is None
|
||||
):
|
||||
# For rename only, just update the name
|
||||
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
|
||||
# When credential type is UNAUTHORIZED, it indicates the subscription was manually created
|
||||
# For Manually created subscription, they dont have credentials, parameters
|
||||
# They only have name and properties(which is input by user)
|
||||
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
|
||||
if rename or manually_created:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
name=request.name,
|
||||
properties=request.properties,
|
||||
)
|
||||
return 200
|
||||
|
||||
# rebuild for create automatically by the provider
|
||||
match subscription.credential_type:
|
||||
case CredentialType.UNAUTHORIZED:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
properties=args.properties,
|
||||
)
|
||||
return 200
|
||||
case CredentialType.API_KEY | CredentialType.OAUTH2:
|
||||
if args.credentials:
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in args.credentials.items()
|
||||
}
|
||||
else:
|
||||
new_credentials = subscription.credentials
|
||||
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=args.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=new_credentials,
|
||||
parameters=args.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
case _:
|
||||
raise BadRequest("Invalid credential type")
|
||||
# For the remaining cases (API_KEY, OAUTH2),
|
||||
# we need to call the third-party provider (e.g. GitHub) to rebuild the subscription
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=request.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=request.credentials or subscription.credentials,
|
||||
parameters=request.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
||||
@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
args = self.parse_args(service_api_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
pinned: bool | None = None
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
@web_ns.route("/conversations")
|
||||
class ConversationListApi(WebApiResource):
|
||||
@web_ns.doc("Get Conversation List")
|
||||
@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = args["pinned"] == "true"
|
||||
raw_args = request.args.to_dict()
|
||||
query = ConversationListQuery.model_validate(raw_args)
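# Editor's note (illustrative, not part of this commit): model_validate coerces the raw query
# strings, e.g. ConversationListQuery.model_validate({"limit": "50", "pinned": "true"}) yields
# limit=50 and pinned=True, which is why the manual 'pinned == "true"' handling removed above
# is no longer needed.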
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=query.last_id,
|
||||
limit=query.limit,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=pinned,
|
||||
sort_by=args["sort_by"],
|
||||
pinned=query.pinned,
|
||||
sort_by=query.sort_by,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json")
|
||||
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
|
||||
app_model, conversation_id, end_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
|
||||
@ -1,18 +1,30 @@
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import uuid_value
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
@web_ns.doc("Get Saved Messages")
|
||||
@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
raw_args = request.args.to_dict()
|
||||
query = SavedMessageListQuery.model_validate(raw_args)
|
||||
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, end_user, args["message_id"])
|
||||
SavedMessageService.save(app_model, end_user, payload.message_id)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
|
||||
@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -40,6 +42,7 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
conversation_variable_layer = ConversationVariablePersistenceLayer(
|
||||
ConversationVariableUpdater(session_factory.get_session_maker())
|
||||
)
|
||||
workflow_entry.graph_engine.layer(conversation_variable_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ class AnnotationReplyFeature:
|
||||
AppAnnotationService.add_annotation_history(
|
||||
annotation.id,
|
||||
app_record.id,
|
||||
annotation.question,
|
||||
annotation.question_text,
|
||||
annotation.content,
|
||||
query,
|
||||
user_id,
|
||||
|
||||
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunSucceededEvent):
|
||||
return
|
||||
if event.node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
return
|
||||
if self.graph_runtime_state is None:
|
||||
return
|
||||
|
||||
updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
|
||||
if not updated_variables:
|
||||
return
|
||||
|
||||
conversation_id = self.graph_runtime_state.system_variable.conversation_id
|
||||
if conversation_id is None:
|
||||
return
|
||||
|
||||
updated_any = False
|
||||
for item in updated_variables:
|
||||
selector = item.selector
|
||||
if len(selector) < 2:
|
||||
logger.warning("Conversation variable selector invalid. selector=%s", selector)
|
||||
continue
|
||||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
)
|
||||
continue
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
|
||||
updated_any = True
|
||||
|
||||
if updated_any:
|
||||
self._conversation_variable_updater.flush()
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
super().__init__()
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
self._generate_entity = generate_entity
|
||||
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
trigger_log_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
):
|
||||
super().__init__()
|
||||
self.trigger_log_id = trigger_log_id
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||
|
||||
# Extract relevant data from result
|
||||
if not self.graph_runtime_state:
|
||||
logger.exception("Graph runtime state is not set")
|
||||
return
|
||||
|
||||
outputs = self.graph_runtime_state.outputs
|
||||
|
||||
# Basically, workflow_execution_id is the same as workflow_run_id
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
_session_maker: sessionmaker | None = None
|
||||
_session_maker: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
|
||||
_session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
|
||||
|
||||
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
if _session_maker is None:
|
||||
raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
|
||||
return _session_maker
|
||||
@ -27,7 +27,7 @@ class SessionFactory:
|
||||
configure_session_factory(engine, expire_on_commit)
|
||||
|
||||
@staticmethod
|
||||
def get_session_maker() -> sessionmaker:
|
||||
def get_session_maker() -> sessionmaker[Session]:
|
||||
return get_session_maker()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -8,8 +8,9 @@ import urllib.parse
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
|
||||
@ -112,17 +112,17 @@ class File(BaseModel):
|
||||
|
||||
return text
|
||||
|
||||
def generate_url(self) -> str | None:
|
||||
def generate_url(self, for_external: bool = True) -> str | None:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
|
||||
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
|
||||
return None
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
@ -133,7 +133,7 @@ class File(BaseModel):
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
"url": self.generate_url(for_external=False),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
|
||||
Post-process the result to convert scientific notation strings back to numbers
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value):
|
||||
def convert_scientific_notation(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result) # type: ignore[no-any-return]
|
||||
return convert_scientific_notation(result)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
||||
@ -984,9 +984,11 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query).replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
|
||||
@ -287,11 +287,15 @@ class IrisVector(BaseVector):
|
||||
cursor.execute(sql, (query,))
|
||||
else:
|
||||
# Fallback to LIKE search (inefficient for large datasets)
|
||||
query_pattern = f"%{query}%"
|
||||
# Escape special characters for LIKE clause to prevent SQL injection
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
escaped_query = escape_like_pattern(query)
|
||||
query_pattern = f"%{escaped_query}%"
|
||||
sql = f"""
|
||||
SELECT TOP {top_k} id, text, meta
|
||||
FROM {self.schema}.{self.table_name}
|
||||
WHERE text LIKE ?
|
||||
WHERE text LIKE ? ESCAPE '\\'
|
||||
"""
|
||||
cursor.execute(sql, (query_pattern,))
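# Editor's note (illustrative sketch, not part of this commit): escape_like_pattern is imported
# from libs.helper in several hunks of this change, but its body is not shown here. A plausible
# implementation, assuming backslash is used as the ESCAPE character (the real helper may differ):
def escape_like_pattern(value: str) -> str:
    # Escape LIKE wildcards so user-supplied text is matched literally.
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

assert escape_like_pattern("50%_off") == "50\\%\\_off"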
|
||||
|
||||
|
||||
@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
|
||||
in a Weaviate collection.
|
||||
"""
|
||||
|
||||
_DOCUMENT_ID_PROPERTY = "document_id"
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
"""
|
||||
Initializes the Weaviate vector store.
|
||||
@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
|
||||
return []
|
||||
|
||||
col = self._client.collections.use(self._collection_name)
|
||||
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
|
||||
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
|
||||
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
|
||||
where = None
|
||||
doc_ids = kwargs.get("document_ids_filter") or []
|
||||
if doc_ids:
|
||||
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
|
||||
where = ors[0]
|
||||
for f in ors[1:]:
|
||||
where = where | f
|
||||
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
|
||||
|
||||
top_k = int(kwargs.get("top_k", 4))
|
||||
|
||||
|
||||
@ -7,10 +7,11 @@ import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import httpx
|
||||
from docx import Document as DocxDocument
|
||||
from docx.oxml.ns import qn
|
||||
from docx.text.run import Run
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
for para in doc.paragraphs:
|
||||
for run in para.runs:
|
||||
if run.text and hyperlinks_url:
|
||||
result = f" [{run.text}]({hyperlinks_url}) "
|
||||
run.text = result
|
||||
hyperlinks_url = None
|
||||
if "HYPERLINK" in run.element.xml:
|
||||
try:
|
||||
xml = ElementTree.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x is None:
|
||||
continue
|
||||
if x.tag.endswith("instrText"):
|
||||
if x.text is None:
|
||||
continue
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse HYPERLINK xml")
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
|
||||
def append_image_link(image_id, has_drawing):
|
||||
def append_image_link(image_id, has_drawing, target_buffer):
|
||||
"""Helper to append image link from image_map based on relationship type."""
|
||||
rel = doc.part.rels[image_id]
|
||||
if rel.is_external:
|
||||
if image_id in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
target_buffer.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map and not has_drawing:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
|
||||
for run in paragraph.runs:
|
||||
def process_run(run, target_buffer):
|
||||
# Helper to extract text and embedded images from a run element and append them to target_buffer
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
# Process drawing type images
|
||||
drawing_elements = run.element.findall(
|
||||
@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
|
||||
# External image: use embed_id as key
|
||||
if embed_id in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[embed_id])
|
||||
target_buffer.append(image_map[embed_id])
|
||||
else:
|
||||
# Internal image: use target_part as key
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
has_drawing = True
|
||||
paragraph_content.append(image_map[image_part])
|
||||
target_buffer.append(image_map[image_part])
|
||||
# Process pict type images
|
||||
shape_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
|
||||
@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
# Find imagedata element in VML
|
||||
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
|
||||
if image_data is not None:
|
||||
@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
|
||||
)
|
||||
if image_id and image_id in doc.part.rels:
|
||||
append_image_link(image_id, has_drawing)
|
||||
append_image_link(image_id, has_drawing, target_buffer)
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
target_buffer.append(run.text.strip())
|
||||
|
||||
def process_hyperlink(hyperlink_elem, target_buffer):
|
||||
# Helper to extract text from a hyperlink element and append it to target_buffer
|
||||
r_id = hyperlink_elem.get(qn("r:id"))
|
||||
|
||||
# Extract text from runs inside the hyperlink
|
||||
link_text_parts = []
|
||||
for run_elem in hyperlink_elem.findall(qn("w:r")):
|
||||
run = Run(run_elem, paragraph)
|
||||
# Hyperlink text may be split across multiple runs (e.g., with different formatting),
|
||||
# so collect all run texts first
|
||||
if run.text:
|
||||
link_text_parts.append(run.text)
|
||||
|
||||
link_text = "".join(link_text_parts).strip()
|
||||
|
||||
# Resolve URL
|
||||
if r_id:
|
||||
try:
|
||||
rel = doc.part.rels.get(r_id)
|
||||
if rel and rel.is_external:
|
||||
link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
|
||||
|
||||
if link_text:
|
||||
target_buffer.append(link_text)
|
||||
|
||||
paragraph_content = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
tag = child.tag
|
||||
if tag == qn("w:r"):
|
||||
# Regular run
|
||||
run = Run(child, paragraph)
|
||||
|
||||
# Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
|
||||
fld_chars = child.findall(qn("w:fldChar"))
|
||||
instr_texts = child.findall(qn("w:instrText"))
|
||||
|
||||
# Handle Fields
|
||||
if fld_chars or instr_texts:
|
||||
# Process instrText to find HYPERLINK "url"
|
||||
for instr in instr_texts:
|
||||
if instr.text and "HYPERLINK" in instr.text:
|
||||
# Quick regex to extract URL
|
||||
match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
|
||||
if match:
|
||||
hyperlink_field_url = match.group(1)
|
||||
|
||||
# Process fldChar
|
||||
for fld_char in fld_chars:
|
||||
fld_char_type = fld_char.get(qn("w:fldCharType"))
|
||||
if fld_char_type == "begin":
|
||||
# Start of a field: reset legacy link state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
elif fld_char_type == "separate":
|
||||
# Separator: if we found a URL, start collecting visible text
|
||||
if hyperlink_field_url:
|
||||
is_collecting_field_text = True
|
||||
elif fld_char_type == "end":
|
||||
# End of field
|
||||
if is_collecting_field_text and hyperlink_field_url:
|
||||
# Create markdown link and append to main content
|
||||
display_text = "".join(hyperlink_field_text_parts).strip()
|
||||
if display_text:
|
||||
link_md = f"[{display_text}]({hyperlink_field_url})"
|
||||
paragraph_content.append(link_md)
|
||||
# Reset state
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts = []
|
||||
is_collecting_field_text = False
|
||||
|
||||
# Decide where to append content
|
||||
target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
|
||||
process_run(run, target_buffer)
|
||||
elif tag == qn("w:hyperlink"):
|
||||
process_hyperlink(child, paragraph_content)
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
|
||||
@ -1198,18 +1198,24 @@ class DatasetRetrieval:
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
from libs.helper import escape_like_pattern
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
escaped_value = escape_like_pattern(str(value))
|
||||
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
@ -1474,38 +1480,38 @@ class DatasetRetrieval:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
|
||||
@ -7,12 +7,12 @@ import time
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
"""
|
||||
sign file to get a temporary url for plugin access
|
||||
"""
|
||||
# Use internal URL for plugin/tool file access in Docker environments
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
# Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
|
||||
base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
|
||||
@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
|
||||
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
|
||||
can assume `graph_runtime_state` is available.
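For example, a minimal layer can read token usage in its end hook without any `None` checks
(a sketch; the imports match paths used elsewhere in this change set, and `engine` is assumed
to be an already-constructed `GraphEngine`):

```python
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent


class TokenUsageLayer(GraphEngineLayer):
    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        # graph_runtime_state is bound when the layer is registered via engine.layer().
        print("total tokens:", self.graph_runtime_state.total_tokens)


engine.layer(TokenUsageLayer())
```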
|
||||
|
||||
### Event-Driven Architecture
|
||||
|
||||
All node executions emit events for monitoring and integration:
|
||||
|
||||
@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@ -113,6 +113,8 @@ class RedisChannel:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
if command_type == CommandType.UPDATE_VARIABLES:
|
||||
return UpdateVariablesCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
@ -5,11 +5,12 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
"UpdateVariablesCommandHandler",
|
||||
]
|
||||
|
||||
@ -4,9 +4,10 @@ from typing import final
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
|
||||
@final
|
||||
class UpdateVariablesCommandHandler(CommandHandler):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, UpdateVariablesCommand)
|
||||
for update in command.updates:
|
||||
try:
|
||||
variable = update.value
|
||||
self._variable_pool.add(variable.selector, variable)
|
||||
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Skipping invalid variable selector %s for workflow %s: %s",
|
||||
getattr(update.value, "selector", None),
|
||||
execution.workflow_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
ABORT = auto()
|
||||
PAUSE = auto()
|
||||
UPDATE_VARIABLES = auto()
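# Editor's note (illustrative, not part of this commit): with StrEnum, auto() lower-cases the
# member name, so the existing serialized values are preserved by this refactor and the new
# member serializes as "update_variables". Quick check:
assert CommandType.ABORT == "abort"
assert CommandType.PAUSE == "pause"
assert CommandType.UPDATE_VARIABLES == "update_variables"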
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
"""Command to update a group of variables in the variable pool."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
|
||||
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
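# Editor's note (illustrative sketch, not part of this commit): a caller builds the new command
# from variables it has already resolved (any VariableUnion instance works); an empty list of
# updates is valid, and the handler then simply does nothing.
def build_update_command(variables: Sequence[VariableUnion]) -> UpdateVariablesCommand:
    return UpdateVariablesCommand(updates=[VariableUpdate(value=v) for v in variables])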
|
||||
|
||||
@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability.
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -30,8 +31,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .command_processing import (
|
||||
AbortCommandHandler,
|
||||
CommandProcessor,
|
||||
PauseCommandHandler,
|
||||
UpdateVariablesCommandHandler,
|
||||
)
|
||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@ -70,10 +76,13 @@ class GraphEngine:
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
@ -140,6 +149,9 @@ class GraphEngine:
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
|
||||
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
@ -169,6 +181,7 @@ class GraphEngine:
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -199,6 +212,7 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -212,9 +226,16 @@ class GraphEngine:
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def _bind_layer_context(
|
||||
self,
|
||||
layer: GraphEngineLayer,
|
||||
) -> None:
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
@ -301,14 +322,7 @@ class GraphEngine:
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
# Create a read-only wrapper for the runtime state
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
@ -316,6 +330,7 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
@ -343,13 +358,12 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
|
||||
@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
@ -34,6 +34,9 @@ engine.layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
`engine.layer()` binds the read-only runtime state before execution, so
|
||||
`graph_runtime_state` is always available inside layer hooks.
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
|
||||
@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayerNotInitializedError(Exception):
|
||||
"""Raised when a layer's runtime state is accessed before initialization."""
|
||||
|
||||
def __init__(self, layer_name: str | None = None) -> None:
|
||||
name = layer_name or "GraphEngineLayer"
|
||||
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
@property
|
||||
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
|
||||
if self._graph_runtime_state is None:
|
||||
raise GraphEngineLayerNotInitializedError(type(self).__name__)
|
||||
return self._graph_runtime_state
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Called by GraphEngine to inject the read-only runtime state and command channel.
|
||||
This is invoked when the layer is registered with a `GraphEngine` instance.
|
||||
Implementations should be idempotent.
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.graph_runtime_state and self.include_outputs:
|
||||
if self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
if self.include_outputs and self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,7 +29,6 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -45,6 +50,16 @@ class GraphEngineManager:
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
@ -44,6 +44,7 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -61,7 +62,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._stop_event = stop_event
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -69,16 +70,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
|
||||
@ -42,6 +42,7 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
@ -65,13 +66,16 @@ class Worker(threading.Thread):
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -41,6 +41,7 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
@ -81,6 +82,7 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +137,7 @@ class WorkerPool:
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
@ -152,6 +154,7 @@ class WorkerPool:
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict

from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()

@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -106,6 +113,14 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)

return node_class(
id=node_id,

@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Protocol

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage


class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""


class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""

def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError


class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""

_code_executor: type[CodeExecutor]

def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor

def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc

rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered
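Because `Jinja2TemplateRenderer` is a `Protocol`, any object with a matching `render_template` signature can be injected into the factory or the node. As a rough illustration only (not part of the diff, and assuming the `jinja2` package is available), an in-process renderer built on Jinja2's sandbox would satisfy it structurally:

from collections.abc import Mapping
from typing import Any

from jinja2.sandbox import SandboxedEnvironment  # assumption: jinja2 is installed


class InProcessJinja2Renderer:
    """Illustrative renderer that skips CodeExecutor and renders locally."""

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        env = SandboxedEnvironment(autoescape=False)
        return env.from_string(template).render(**variables)


# Duck-typed: structurally compatible with the Jinja2TemplateRenderer Protocol,
# so it could be passed wherever the diff allows a `template_renderer=` argument.
renderer = InProcessJinja2Renderer()
print(renderer.render_template("Hello {{ name }}!", {"name": "Dify"}))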
@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)

if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH


class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()

@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)

@classmethod
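The point of injecting `template_renderer` is testability: a unit test can exercise the length check and the success path without a running CodeExecutor sandbox. A hedged sketch of such a test double; the node's other constructor arguments are elided because they come from whatever fixtures the suite uses:

from collections.abc import Mapping
from typing import Any


class FakeRenderer:
    """Test double satisfying the Jinja2TemplateRenderer Protocol."""

    def __init__(self, canned_output: str) -> None:
        self.canned_output = canned_output
        self.calls: list[tuple[str, Mapping[str, Any]]] = []

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        self.calls.append((template, variables))
        return self.canned_output


# In a test the fake would be handed to the node, e.g.
#   node = TemplateTransformNode(..., template_renderer=FakeRenderer("hello"))
# and asserting on fake.calls verifies the template and variables the node passed in.
fake = FakeRenderer("hello")
assert fake.render_template("{{ x }}", {"x": 1}) == "hello"
assert fake.calls == [("{{ x }}", {"x": 1})]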
@ -1,28 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable

from .exc import VariableOperatorNodeError


class ConversationVariableUpdaterImpl:
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

def flush(self):
pass


def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

@ -1,9 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError

from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode

if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState


_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

def __init__(
self,
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._conv_var_updater_factory = conv_var_updater_factory

@classmethod
def version(cls) -> str:
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={

@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory

from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)

if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState


def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):

return False

def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()

@classmethod
def version(cls) -> str:
return "2"
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))

conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value

if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors

@ -2,6 +2,7 @@ from __future__ import annotations

import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -168,6 +169,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()

if graph is not None:
self.attach_graph(graph)

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""

def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...


@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any

@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool

def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None

def get_all_by_node(self, node_id: str) -> Mapping[str, object]:

@ -3,8 +3,9 @@
set -e

# Set UTF-8 encoding to address potential encoding issues in containerized environments
export LANG=${LANG:-en_US.UTF-8}
export LC_ALL=${LC_ALL:-en_US.UTF-8}
# Use C.UTF-8 which is universally available in all containers
export LANG=${LANG:-C.UTF-8}
export LC_ALL=${LC_ALL:-C.UTF-8}
export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8}

if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")


@ -42,10 +43,28 @@ def init_app(app: DifyApp):

_apply_cors_once(
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
resources={
# Embedded bot endpoints (unauthenticated, cross-origin safe)
r"^/chat-messages$": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
r"^/chat-messages/.*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": False,
"allow_headers": list(EMBED_HEADERS),
"methods": ["GET", "POST", "OPTIONS"],
},
# Default web application endpoints (authenticated)
r"/*": {
"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
"supports_credentials": True,
"allow_headers": list(AUTHENTICATED_HEADERS),
"methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
},
},
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)
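The per-resource mapping above means the embedded `/chat-messages` endpoints answer CORS preflight requests without the credentials flag and with the smaller `EMBED_HEADERS` list, while every other web path keeps the authenticated behaviour. A rough, non-authoritative way to inspect that split with Flask's test client; the request path and origin are illustrative and not taken from the diff:

# Sketch only: assumes `app` is the configured Flask app and the origin below
# is listed in WEB_API_CORS_ALLOW_ORIGINS.
client = app.test_client()

preflight = client.options(
    "/chat-messages",
    headers={
        "Origin": "https://widget.example.com",
        "Access-Control-Request-Method": "POST",
    },
)
# For the embedded route the expectation is an allowed origin without the
# credentials flag; authenticated routes should also advertise the CSRF header.
print(preflight.headers.get("Access-Control-Allow-Origin"))
print(preflight.headers.get("Access-Control-Allow-Credentials"))
print(preflight.headers.get("Access-Control-Allow-Headers"))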
@ -11,6 +11,7 @@ def init_app(app: DifyApp):
create_tenant,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@ -47,6 +48,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,

@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()

# Eagerly build the engine so pool_size/max_overflow/etc. come from config
try:
with app.app_context():
_ = db.engine # triggers engine creation with the configured options
except Exception:
logger.exception("Failed to initialize SQLAlchemy engine during app startup")

@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)


def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.

- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)


class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -69,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"

# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"

def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
@ -108,9 +125,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),
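The `default=to_serializable` hook only kicks in when `json.dumps` meets a value it cannot encode natively, so plain dicts are unaffected while objects exposing `to_dict()` (or anything else, via `str()`) no longer raise `TypeError`. A small self-contained illustration with a made-up `FileRef` class, not a class from the diff:

import json


class FileRef:
    """Hypothetical non-JSON-serializable value, similar in spirit to workflow segments."""

    def __init__(self, url: str) -> None:
        self.url = url

    def to_dict(self) -> dict:
        return {"url": self.url}


def to_serializable(obj):
    # same fallback order as the repository helper above
    if hasattr(obj, "to_dict") and callable(obj.to_dict):
        return obj.to_dict()
    return str(obj)


payload = {"file": FileRef("https://example.com/a.png"), "count": 3}
print(json.dumps(payload, ensure_ascii=False, default=to_serializable))
# {"file": {"url": "https://example.com/a.png"}, "count": 3}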
@ -32,6 +32,38 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.

This function escapes the special characters used in SQL LIKE patterns:
- Backslash (\\) -> \\
- Percent (%) -> \\%
- Underscore (_) -> \\_

The escaped pattern can then be safely used in SQL LIKE queries with the
ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards.

Args:
pattern: The string pattern to escape

Returns:
Escaped string safe for use in SQL LIKE queries

Examples:
>>> escape_like_pattern("50% discount")
'50\\% discount'
>>> escape_like_pattern("test_data")
'test\\_data'
>>> escape_like_pattern("path\\to\\file")
'path\\\\to\\\\file'
"""
if not pattern:
return pattern
# Escape backslash first, then percent and underscore
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
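Every call site in the rest of the diff pairs this helper with SQLAlchemy's `escape` argument, so the escaped `\%`, `\_`, and `\\` sequences are treated literally by the database. A condensed, self-contained sketch of that pairing; the `Item` model and the in-memory SQLite engine are placeholders, not code from the repository:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Item(Base):  # hypothetical table, standing in for App/Dataset/Tag/etc.
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


def escape_like_pattern(pattern: str) -> str:
    # same implementation as the helper above
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with sessionmaker(bind=engine)() as session:
    session.add_all([Item(name="50% discount"), Item(name="100% different")])
    session.commit()
    escaped = escape_like_pattern("50%")
    rows = session.scalars(select(Item).where(Item.name.ilike(f"%{escaped}%", escape="\\"))).all()
    # The literal "%" no longer acts as a wildcard, so only the exact match is returned.
    assert [r.name for r in rows] == ["50% discount"]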
@ -70,6 +70,7 @@ class AppMode(StrEnum):
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
LINK = auto()


class App(Base):
@ -81,7 +82,7 @@ class App(Base):
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@ -1419,15 +1420,20 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
question = mapped_column(LongText, nullable=True)
content = mapped_column(LongText, nullable=False)
question: Mapped[str | None] = mapped_column(LongText, nullable=True)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

@property
def question_text(self) -> str:
"""Return a non-null question string, falling back to the answer content."""
return self.question or self.content

@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()

@ -1506,6 +1506,7 @@ class WorkflowDraftVariable(Base):
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
variable.updated_at = naive_utc_now()
variable.description = description

@ -77,7 +77,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
annotation.question,
question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@ -137,13 +137,16 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(keyword)
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike(f"%{keyword}%"),
MessageAnnotation.content.ilike(f"%{keyword}%"),
MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
@ -253,7 +256,7 @@ class AppAnnotationService:
if app_annotation_setting:
update_annotation_to_index_task.delay(
annotation.id,
annotation.question,
annotation.question_text,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,

@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig
from models.model import AppModelConfig, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@ -428,10 +428,10 @@ class AppDslService:

# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
if icon_type_value in ["emoji", "link", "image"]:
if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon_type = IconType.EMOJI.value
icon = icon or str(app_data.get("icon", ""))

if app:

@ -55,8 +55,11 @@ class AppService:
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
from libs.helper import escape_like_pattern

name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])

@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@ -218,7 +218,9 @@ class ConversationService:
# Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
from libs.helper import escape_like_pattern

escaped_variable_name = escape_like_pattern(variable_name)
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
@ -335,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)

# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()

api/services/conversation_variable_updater.py (new file, 28 lines)
@ -0,0 +1,28 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from core.variables.variables import Variable
from models import ConversationVariable


class ConversationVariableNotFoundError(Exception):
pass


class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker

def update(self, conversation_id: str, variable: Variable) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with self._session_maker() as session:
row = session.scalar(stmt)
if not row:
raise ConversationVariableNotFoundError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

def flush(self) -> None:
pass
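Moving the updater behind an injected `sessionmaker` (instead of the deleted `extensions.ext_database`-backed implementation) keeps `core.workflow` free of infrastructure imports and makes the persistence step easy to wire or stub at the service layer. A brief sketch of both directions; the in-memory engine is only a stand-in, and the real `ConversationVariable` table would still need to exist for `update()` to find rows:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

# Production-style wiring mirrors the ConversationService change in the diff:
#   updater = ConversationVariableUpdater(session_factory.get_session_maker())
#   updater.update(conversation_id, updated_variable)

# In a test, any sessionmaker works; here an in-memory SQLite engine as a stand-in.
engine = create_engine("sqlite:///:memory:")
fake_session_maker: sessionmaker[Session] = sessionmaker(bind=engine)
updater = ConversationVariableUpdater(fake_session_maker)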
@ -144,7 +144,8 @@ class DatasetService:
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)

if search:
query = query.where(Dataset.name.ilike(f"%{search}%"))
escaped_search = helper.escape_like_pattern(search)
query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))

# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
@ -3423,7 +3424,8 @@ class SegmentService:
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

@classmethod
@ -3456,7 +3458,8 @@ class SegmentService:
query = query.where(DocumentSegment.status.in_(status_list))

if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))

query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

@ -35,7 +35,10 @@ class ExternalDatasetService:
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
from libs.helper import escape_like_pattern

escaped_search = escape_like_pattern(search)
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))

external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False

@ -19,7 +19,10 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

@ -853,7 +853,7 @@ class TriggerProviderService:
"""
Create a subscription builder for rebuilding an existing subscription.

This method creates a builder pre-filled with data from the rebuild request,
This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub),
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.

:param tenant_id: Tenant ID
@ -868,111 +868,50 @@ class TriggerProviderService:
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")

# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")

credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
raise ValueError(f"Credential type {credential_type} not supported for auto creation")

# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# Delete the previous subscription
user_id = subscription.user_id
unsubscribe_result = TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
if not unsubscribe_result.success:
raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")

# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}

user_id = subscription.user_id

# TODO: Trying to invoke update api of the plugin trigger provider

# FALLBACK: If the update api is not implemented,
# delete the previous subscription and create a new one

# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally

# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)

# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name

# Update parameters
subscription.parameters = dict(parameters)

# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))

# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))

# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at

# Commit the transaction
session.commit()

# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)

except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise
# Create a new subscription with the same subscription_id and endpoint_id
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=new_credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=new_credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)
@ -86,12 +86,19 @@ class WorkflowAppService:
# Join to workflow run for filtering when needed.

if keyword:
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
from libs.helper import escape_like_pattern

# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword[:30])
keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
and_(
WorkflowRun.created_by_role == "end_user",
EndUser.session_id.ilike(keyword_like_val, escape="\\"),
),
]

# filter keyword by workflow run id

@ -679,6 +679,7 @@ def _batch_upsert_draft_variable(

def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"id": model.id,
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,

@ -98,7 +98,7 @@ def enable_annotation_reply_task(
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers:
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
# Don't initialize - graph_runtime_state should be uninitialized

event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])

# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
# Act & Assert - Should raise GraphEngineLayerNotInitializedError
with pytest.raises(GraphEngineLayerNotInitializedError):
layer.on_event(event)
@ -444,6 +444,78 @@ class TestAnnotationService:
assert total == 1
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content

def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
):
r"""
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.

This test verifies:
- Special characters (%, _, \) in keyword are properly escaped
- Search treats special characters as literal characters, not wildcards
- SQL injection via LIKE wildcards is prevented
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

# Create annotations with special characters in content
annotation_with_percent = {
"question": "Question with 50% discount",
"answer": "Answer about 50% discount offer",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id)

annotation_with_underscore = {
"question": "Question with test_data",
"answer": "Answer about test_data value",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id)

annotation_with_backslash = {
"question": "Question with path\\to\\file",
"answer": "Answer about path\\to\\file location",
}
AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id)

# Create annotation that should NOT match (contains % but as part of different text)
annotation_no_match = {
"question": "Question with 100% different",
"answer": "Answer about 100% different content",
}
AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id)

# Test 1: Search with % character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
assert total == 1
assert len(annotation_list) == 1
assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content

# Test 2: Search with _ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="test_data"
)
assert total == 1
assert len(annotation_list) == 1
assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content

# Test 3: Search with \ character - should find exact match only
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="path\\to\\file"
)
assert total == 1
assert len(annotation_list) == 1
assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content

# Test 4: Search with % should NOT match 100% (verifies escaping works)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
app.id, page=1, limit=10, keyword="50%"
)
# Should only find the 50% annotation, not the 100% one
assert total == 1
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)

def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppService:
|
||||
@ -71,6 +73,9 @@ class TestAppService:
|
||||
}
|
||||
|
||||
# Create app
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -109,6 +114,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Test different app modes
|
||||
@ -159,6 +167,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
created_app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -194,6 +205,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create multiple apps
|
||||
@ -245,6 +259,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with different modes
|
||||
@ -315,6 +332,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create an app
|
||||
@ -392,6 +412,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -458,6 +481,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -508,6 +534,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -562,6 +591,9 @@ class TestAppService:
|
||||
"icon_background": "#74B9FF",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -617,6 +649,9 @@ class TestAppService:
|
||||
"icon_background": "#A29BFE",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -672,6 +707,9 @@ class TestAppService:
|
||||
"icon_background": "#FD79A8",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -720,6 +758,9 @@ class TestAppService:
|
||||
"icon_background": "#E17055",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -768,6 +809,9 @@ class TestAppService:
|
||||
"icon_background": "#00B894",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -826,6 +870,9 @@ class TestAppService:
|
||||
"icon_background": "#6C5CE7",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -862,6 +909,9 @@ class TestAppService:
|
||||
"icon_background": "#FDCB6E",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -899,6 +949,9 @@ class TestAppService:
|
||||
"icon_background": "#E84393",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -947,8 +1000,132 @@ class TestAppService:
|
||||
"icon_background": "#D63031",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Attempt to create app with invalid mode
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
    def test_get_apps_with_special_characters_in_name(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        r"""
        Test app retrieval with special characters in name search to verify SQL injection prevention.

        This test verifies:
        - Special characters (%, _, \) in name search are properly escaped
        - Search treats special characters as literal characters, not wildcards
        - SQL injection via LIKE wildcards is prevented
        """
        fake = Faker()

        # Create account and tenant first
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService

        app_service = AppService()

        # Create apps with special characters in names
        app_with_percent = app_service.create_app(
            tenant.id,
            {
                "name": "App with 50% discount",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

        app_with_underscore = app_service.create_app(
            tenant.id,
            {
                "name": "test_data_app",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

        app_with_backslash = app_service.create_app(
            tenant.id,
            {
                "name": "path\\to\\app",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

        # Create app that should NOT match
        app_no_match = app_service.create_app(
            tenant.id,
            {
                "name": "100% different",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

        # Test 1: Search with % character
        args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "App with 50% discount"

        # Test 2: Search with _ character
        args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "test_data_app"

        # Test 3: Search with \ character
        args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "path\\to\\app"

        # Test 4: Search with % should NOT match 100% (verifies escaping works)
        args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert all("50%" in app.name for app in paginated_apps.items)

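# For reference: a minimal sketch of the LIKE-escaping behavior these assertions rely on,
# assuming a SQLAlchemy-style name filter. The helper below is illustrative only and is
# not the actual AppService implementation.
def _escape_like(keyword: str, escape: str = "\\") -> str:
    # Escape the escape character first, then the LIKE wildcards, so user input
    # such as "50%" or "test_data" is matched literally instead of as a pattern.
    return (
        keyword.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


# Hypothetical usage inside a name filter:
#     stmt = select(App).where(App.name.ilike(f"%{_escape_like(name)}%", escape="\\"))
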
@ -1,3 +1,4 @@
import uuid
from unittest.mock import create_autospec, patch

import pytest

@ -312,6 +313,85 @@ class TestTagService:
        result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
        assert len(result_no_match) == 0

def test_get_tags_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _, \) in keyword are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create tags with special characters in names
|
||||
tag_with_percent = Tag(
|
||||
name="50% discount",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_percent.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_percent)
|
||||
|
||||
tag_with_underscore = Tag(
|
||||
name="test_data_tag",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_underscore.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_underscore)
|
||||
|
||||
tag_with_backslash = Tag(
|
||||
name="path\\to\\tag",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_backslash.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_backslash)
|
||||
|
||||
# Create tag that should NOT match
|
||||
tag_no_match = Tag(
|
||||
name="100% different",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_no_match.id = str(uuid.uuid4())
|
||||
db.session.add(tag_no_match)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Test 1 - Search with % character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "50% discount"
|
||||
|
||||
# Test 2 - Search with _ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="test_data")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "test_data_tag"
|
||||
|
||||
# Test 3 - Search with \ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "path\\to\\tag"
|
||||
|
||||
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert all("50%" in item.name for item in result)
|
||||
|
||||
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag retrieval when no tags exist.
|
||||
|
||||
@ -474,64 +474,6 @@ class TestTriggerProviderService:
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||
|
||||
This test verifies:
|
||||
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||
- Rebuild continues even if unsubscribe fails
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||
"Unsubscribe failed"
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Execute rebuild - should succeed despite unsubscribe error
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was still called (operation continued)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
|
||||
# Verify subscription was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.parameters == {}
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
@ -558,70 +500,6 @@ class TestTriggerProviderService:
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when provider is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when provider doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||
|
||||
# Make get_trigger_provider return None
|
||||
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake.uuid4(),
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when credential type is not supported for rebuild.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from services.account_service import AccountService, TenantService
from services.app_service import AppService

# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService


@ -86,6 +88,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -147,6 +152,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -308,6 +316,156 @@ class TestWorkflowAppService:
|
||||
assert result_no_match["total"] == 0
|
||||
assert len(result_no_match["data"]) == 0
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _) in keyword are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = WorkflowAppService()
|
||||
|
||||
# Test 1: Search with % character
|
||||
workflow_run_1 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "50% discount applied", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_1)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_1 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_1.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_1.id = str(uuid.uuid4())
|
||||
workflow_app_log_1.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_1)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
)
|
||||
# Should find the workflow_run_1 entry
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"])
|
||||
|
||||
# Test 2: Search with _ character
|
||||
workflow_run_2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "test_data_value found", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_2)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_2 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_2.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_2.id = str(uuid.uuid4())
|
||||
workflow_app_log_2.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_2)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
|
||||
)
|
||||
# Should find the workflow_run_2 entry
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"])
|
||||
|
||||
# Test 3: Search with % should NOT match 100% (verifies escaping works correctly)
|
||||
workflow_run_4 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "100% different result", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_4)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_4 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_4.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_4.id = str(uuid.uuid4())
|
||||
workflow_app_log_4.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_4)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
)
|
||||
# Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4)
|
||||
# This verifies that escaping works correctly - 50% should not match 100%
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
# Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different)
|
||||
found_run_ids = [log.workflow_run_id for log in result["data"]]
|
||||
assert workflow_run_1.id in found_run_ids
|
||||
assert workflow_run_4.id not in found_run_ids
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_status_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
api/tests/unit_tests/controllers/__init__.py (new empty file)
@ -0,0 +1,285 @@
from __future__ import annotations

import builtins
import sys
from datetime import datetime
from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any

import pytest
from flask.views import MethodView

# kombu references MethodView as a global when importing celery/kombu pools.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


def _load_app_module():
|
||||
module_name = "controllers.console.app.app"
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
root = Path(__file__).resolve().parents[5]
|
||||
module_path = root / "controllers" / "console" / "app" / "app.py"
|
||||
|
||||
class _StubNamespace:
|
||||
def __init__(self):
|
||||
self.models: dict[str, Any] = {}
|
||||
self.payload = None
|
||||
|
||||
def schema_model(self, name, schema):
|
||||
self.models[name] = schema
|
||||
|
||||
def _decorator(self, obj):
|
||||
return obj
|
||||
|
||||
def doc(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def expect(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def response(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def route(self, *args, **kwargs):
|
||||
def decorator(obj):
|
||||
return obj
|
||||
|
||||
return decorator
|
||||
|
||||
stub_namespace = _StubNamespace()
|
||||
|
||||
original_console = sys.modules.get("controllers.console")
|
||||
original_app_pkg = sys.modules.get("controllers.console.app")
|
||||
stubbed_modules: list[tuple[str, ModuleType | None]] = []
|
||||
|
||||
console_module = ModuleType("controllers.console")
|
||||
console_module.__path__ = [str(root / "controllers" / "console")]
|
||||
console_module.console_ns = stub_namespace
|
||||
console_module.api = None
|
||||
console_module.bp = None
|
||||
sys.modules["controllers.console"] = console_module
|
||||
|
||||
app_package = ModuleType("controllers.console.app")
|
||||
app_package.__path__ = [str(root / "controllers" / "console" / "app")]
|
||||
sys.modules["controllers.console.app"] = app_package
|
||||
console_module.app = app_package
|
||||
|
||||
def _stub_module(name: str, attrs: dict[str, Any]):
|
||||
original = sys.modules.get(name)
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
sys.modules[name] = module
|
||||
stubbed_modules.append((name, original))
|
||||
|
||||
class _OpsTraceManager:
|
||||
@staticmethod
|
||||
def get_app_tracing_config(app_id: str) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def update_app_tracing_config(app_id: str, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
_stub_module(
|
||||
"core.ops.ops_trace_manager",
|
||||
{
|
||||
"OpsTraceManager": _OpsTraceManager,
|
||||
"TraceQueueManager": object,
|
||||
"TraceTask": object,
|
||||
},
|
||||
)
|
||||
|
||||
spec = util.spec_from_file_location(module_name, module_path)
|
||||
module = util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
try:
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
finally:
|
||||
for name, original in reversed(stubbed_modules):
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
if original_console is not None:
|
||||
sys.modules["controllers.console"] = original_console
|
||||
else:
|
||||
sys.modules.pop("controllers.console", None)
|
||||
if original_app_pkg is not None:
|
||||
sys.modules["controllers.console.app"] = original_app_pkg
|
||||
else:
|
||||
sys.modules.pop("controllers.console.app", None)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
_app_module = _load_app_module()
|
||||
AppDetailWithSite = _app_module.AppDetailWithSite
|
||||
AppPagination = _app_module.AppPagination
|
||||
AppPartial = _app_module.AppPartial
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_signed_url(monkeypatch):
|
||||
"""Ensure icon URL generation uses a deterministic helper for tests."""
|
||||
|
||||
def _fake_signed_url(key: str | None) -> str | None:
|
||||
if not key:
|
||||
return None
|
||||
return f"signed:{key}"
|
||||
|
||||
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
|
||||
|
||||
def _ts(hour: int = 12) -> datetime:
|
||||
return datetime(2024, 1, 1, hour, 0, 0)
|
||||
|
||||
|
||||
def _dummy_model_config():
|
||||
return SimpleNamespace(
|
||||
model_dict={"provider": "openai", "name": "gpt-4o"},
|
||||
pre_prompt="hello",
|
||||
created_by="config-author",
|
||||
created_at=_ts(9),
|
||||
updated_by="config-editor",
|
||||
updated_at=_ts(10),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_workflow():
|
||||
return SimpleNamespace(
|
||||
id="wf-1",
|
||||
created_by="workflow-author",
|
||||
created_at=_ts(8),
|
||||
updated_by="workflow-editor",
|
||||
updated_at=_ts(9),
|
||||
)
|
||||
|
||||
|
||||
def test_app_partial_serialization_uses_aliases():
|
||||
created_at = _ts()
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-1",
|
||||
name="My App",
|
||||
desc_or_prompt="Prompt snippet",
|
||||
mode_compatible_with_agent="chat",
|
||||
icon_type="image",
|
||||
icon="icon-key",
|
||||
icon_background="#fff",
|
||||
app_model_config=_dummy_model_config(),
|
||||
workflow=_dummy_workflow(),
|
||||
created_by="creator",
|
||||
created_at=created_at,
|
||||
updated_by="editor",
|
||||
updated_at=created_at,
|
||||
tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")],
|
||||
access_mode="private",
|
||||
create_user_name="Creator",
|
||||
author_name="Author",
|
||||
has_draft_trigger=True,
|
||||
)
|
||||
|
||||
serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["description"] == "Prompt snippet"
|
||||
assert serialized["mode"] == "chat"
|
||||
assert serialized["icon_url"] == "signed:icon-key"
|
||||
assert serialized["created_at"] == int(created_at.timestamp())
|
||||
assert serialized["updated_at"] == int(created_at.timestamp())
|
||||
assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
|
||||
assert serialized["workflow"]["id"] == "wf-1"
|
||||
assert serialized["tags"][0]["name"] == "Utilities"
|
||||
|
||||
|
||||
def test_app_detail_with_site_includes_nested_serialization():
|
||||
timestamp = _ts(14)
|
||||
site = SimpleNamespace(
|
||||
code="site-code",
|
||||
title="Public Site",
|
||||
icon_type="image",
|
||||
icon="site-icon",
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-2",
|
||||
name="Detailed App",
|
||||
description="Desc",
|
||||
mode_compatible_with_agent="advanced-chat",
|
||||
icon_type="image",
|
||||
icon="detail-icon",
|
||||
icon_background="#123456",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
app_model_config={
|
||||
"opening_statement": "hi",
|
||||
"model": {"provider": "openai", "name": "gpt-4o"},
|
||||
"retriever_resource": {"enabled": True},
|
||||
},
|
||||
workflow=_dummy_workflow(),
|
||||
tracing={"enabled": True},
|
||||
use_icon_as_answer_icon=True,
|
||||
created_by="creator",
|
||||
created_at=timestamp,
|
||||
updated_by="editor",
|
||||
updated_at=timestamp,
|
||||
access_mode="public",
|
||||
tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")],
|
||||
api_base_url="https://api.example.com/v1",
|
||||
max_active_requests=5,
|
||||
deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}],
|
||||
site=site,
|
||||
)
|
||||
|
||||
serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["icon_url"] == "signed:detail-icon"
|
||||
assert serialized["model_config"]["retriever_resource"] == {"enabled": True}
|
||||
assert serialized["deleted_tools"][0]["tool_name"] == "search"
|
||||
assert serialized["site"]["icon_url"] == "signed:site-icon"
|
||||
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
|
||||
|
||||
|
||||
def test_app_pagination_aliases_per_page_and_has_next():
|
||||
item_one = SimpleNamespace(
|
||||
id="app-10",
|
||||
name="Paginated One",
|
||||
desc_or_prompt="Summary",
|
||||
mode_compatible_with_agent="chat",
|
||||
icon_type="image",
|
||||
icon="first-icon",
|
||||
created_at=_ts(15),
|
||||
updated_at=_ts(15),
|
||||
)
|
||||
item_two = SimpleNamespace(
|
||||
id="app-11",
|
||||
name="Paginated Two",
|
||||
desc_or_prompt="Summary",
|
||||
mode_compatible_with_agent="agent-chat",
|
||||
icon_type="emoji",
|
||||
icon="🙂",
|
||||
created_at=_ts(16),
|
||||
updated_at=_ts(16),
|
||||
)
|
||||
pagination = SimpleNamespace(
|
||||
page=2,
|
||||
per_page=10,
|
||||
total=50,
|
||||
has_next=True,
|
||||
items=[item_one, item_two],
|
||||
)
|
||||
|
||||
serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["page"] == 2
|
||||
assert serialized["limit"] == 10
|
||||
assert serialized["has_more"] is True
|
||||
assert len(serialized["data"]) == 2
|
||||
assert serialized["data"][0]["icon_url"] == "signed:first-icon"
|
||||
assert serialized["data"][1]["icon_url"] is None
|
||||
@ -0,0 +1,145 @@
"""
Test for document detail API data_source_info serialization fix.

This test verifies that the document detail API returns both data_source_info
and data_source_detail_dict for all data_source_type values, including "local_file".
"""

import json
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union

from models.dataset import Document


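# For orientation: a minimal sketch of the kind of property these tests exercise.
# This is an assumption for illustration only, not the actual models.dataset.Document
# implementation; the real model also handles upload-file lookups and other branches.
class _DocumentSketch:
    data_source_info: str | None = None

    @property
    def data_source_info_dict(self) -> dict:
        # Parse the stored JSON string; fall back to an empty dict when unset.
        if not self.data_source_info:
            return {}
        return json.loads(self.data_source_info)
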
class LocalFileInfo(TypedDict):
|
||||
file_path: str
|
||||
size: int
|
||||
created_at: NotRequired[str]
|
||||
|
||||
|
||||
class UploadFileInfo(TypedDict):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class NotionImportInfo(TypedDict):
|
||||
notion_page_id: str
|
||||
workspace_id: str
|
||||
|
||||
|
||||
class WebsiteCrawlInfo(TypedDict):
|
||||
url: str
|
||||
job_id: str
|
||||
|
||||
|
||||
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
|
||||
T_type = TypeVar("T_type", bound=str)
|
||||
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
|
||||
|
||||
|
||||
class Case(TypedDict, Generic[T_type, T_info]):
|
||||
data_source_type: T_type
|
||||
data_source_info: str
|
||||
expected_raw: T_info
|
||||
|
||||
|
||||
LocalFileCase = Case[Literal["local_file"], LocalFileInfo]
|
||||
UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
|
||||
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
|
||||
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
|
||||
|
||||
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
|
||||
|
||||
|
||||
case_1: LocalFileCase = {
|
||||
"data_source_type": "local_file",
|
||||
"data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}),
|
||||
"expected_raw": {"file_path": "/tmp/test.txt", "size": 1024},
|
||||
}
|
||||
|
||||
|
||||
# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo
|
||||
case_2: LocalFileCase = {
|
||||
"data_source_type": "local_file",
|
||||
"data_source_info": "...",
|
||||
"expected_raw": {"file_path": "https://google.com", "size": 123},
|
||||
}
|
||||
|
||||
cases: list[AnyCase] = [case_1]
|
||||
|
||||
|
||||
class TestDocumentDetailDataSourceInfo:
|
||||
"""Test cases for document detail API data_source_info serialization."""
|
||||
|
||||
def test_data_source_info_dict_returns_raw_data(self):
|
||||
"""Test that data_source_info_dict returns raw JSON data for all data_source_type values."""
|
||||
# Test data for different data_source_type values
|
||||
for case in cases:
|
||||
document = Document(
|
||||
data_source_type=case["data_source_type"],
|
||||
data_source_info=case["data_source_info"],
|
||||
)
|
||||
|
||||
# Test data_source_info_dict (raw data)
|
||||
raw_result = document.data_source_info_dict
|
||||
assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}"
|
||||
|
||||
# Verify raw_result is always a valid dict
|
||||
assert isinstance(raw_result, dict)
|
||||
|
||||
def test_local_file_data_source_info_without_db_context(self):
|
||||
"""Test that local_file type data_source_info_dict works without database context."""
|
||||
test_data: LocalFileInfo = {
|
||||
"file_path": "/local/path/document.txt",
|
||||
"size": 512,
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
document = Document(
|
||||
data_source_type="local_file",
|
||||
data_source_info=json.dumps(test_data),
|
||||
)
|
||||
|
||||
# data_source_info_dict should return the raw data (this doesn't need DB context)
|
||||
raw_data = document.data_source_info_dict
|
||||
assert raw_data == test_data
|
||||
assert isinstance(raw_data, dict)
|
||||
|
||||
# Verify the data contains expected keys for pipeline mode
|
||||
assert "file_path" in raw_data
|
||||
assert "size" in raw_data
|
||||
|
||||
def test_notion_and_website_crawl_data_source_detail(self):
|
||||
"""Test that notion_import and website_crawl return raw data in data_source_detail_dict."""
|
||||
# Test notion_import
|
||||
notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"}
|
||||
document = Document(
|
||||
data_source_type="notion_import",
|
||||
data_source_info=json.dumps(notion_data),
|
||||
)
|
||||
|
||||
# data_source_detail_dict should return raw data for notion_import
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == notion_data
|
||||
|
||||
# Test website_crawl
|
||||
website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"}
|
||||
document = Document(
|
||||
data_source_type="website_crawl",
|
||||
data_source_info=json.dumps(website_data),
|
||||
)
|
||||
|
||||
# data_source_detail_dict should return raw data for website_crawl
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == website_data
|
||||
|
||||
def test_local_file_data_source_detail_dict_without_db(self):
|
||||
"""Test that local_file returns empty data_source_detail_dict (this doesn't need DB context)."""
|
||||
# Test local_file - this should work without database context since it returns {} early
|
||||
document = Document(
|
||||
data_source_type="local_file",
|
||||
data_source_info=json.dumps({"file_path": "/tmp/test.txt"}),
|
||||
)
|
||||
|
||||
# Should return empty dict for local_file type (handled in the model)
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == {}
|
||||
@ -0,0 +1,144 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.variables import StringVariable
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, selector: Sequence[str]) -> Segment | None:
|
||||
if len(selector) < 2:
|
||||
return None
|
||||
return self._variables.get((selector[0], selector[1]))
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
|
||||
|
||||
|
||||
def _build_graph_runtime_state(
|
||||
variable_pool: MockReadOnlyVariablePool,
|
||||
conversation_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
|
||||
graph_runtime_state.variable_pool = variable_pool
|
||||
graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
|
||||
return graph_runtime_state
|
||||
|
||||
|
||||
def _build_node_run_succeeded_event(
|
||||
*,
|
||||
node_type: NodeType,
|
||||
outputs: dict[str, object] | None = None,
|
||||
process_data: dict[str, object] | None = None,
|
||||
) -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="node-exec-id",
|
||||
node_id="assigner",
|
||||
node_type=node_type,
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs or {},
|
||||
process_data=process_data or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_persists_conversation_variables_from_assigner_output():
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-1",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
|
||||
updater.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_skips_when_outputs_missing():
|
||||
conversation_id = "conv-456"
|
||||
variable = StringVariable(
|
||||
id="var-2",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_assigner_nodes():
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_conversation_variables():
|
||||
conversation_id = "conv-789"
|
||||
non_conversation_variable = StringVariable(
|
||||
id="var-3",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=["environment", "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool()
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -15,6 +16,7 @@ from core.app.layers.pause_state_persist_layer import (
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPausedEvent,
|
||||
@ -66,8 +68,10 @@ class MockReadOnlyVariablePool:
    def __init__(self, variables: dict[tuple[str, str], object] | None = None):
        self._variables = variables or {}

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        value = self._variables.get((node_id, variable_key))
    def get(self, selector: Sequence[str]) -> Segment | None:
        if len(selector) < 2:
            return None
        value = self._variables.get((selector[0], selector[1]))
        if value is None:
            return None
        mock_segment = Mock(spec=Segment)
@ -209,8 +213,9 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
assert layer._state_owner_user_id == state_owner_user_id
|
||||
assert not hasattr(layer, "graph_runtime_state")
|
||||
assert not hasattr(layer, "command_channel")
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
assert layer.command_channel is None
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
@ -295,7 +300,7 @@ class TestPauseStatePersistenceLayer:
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
@ -305,7 +310,7 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
layer.on_event(event)
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@ -1,8 +1,12 @@
"""Primarily used for testing merged cell scenarios"""

import os
import tempfile
from types import SimpleNamespace

from docx import Document
from docx.oxml import OxmlElement
from docx.oxml.ns import qn

import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url():
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
|
||||
def test_extract_hyperlinks(monkeypatch):
|
||||
# Mock db and storage to avoid issues during image extraction (even if no images are present)
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
doc = Document()
|
||||
p = doc.add_paragraph("Visit ")
|
||||
|
||||
# Adding a hyperlink manually
|
||||
r_id = "rId99"
|
||||
hyperlink = OxmlElement("w:hyperlink")
|
||||
hyperlink.set(qn("r:id"), r_id)
|
||||
|
||||
new_run = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "Dify"
|
||||
new_run.append(t)
|
||||
hyperlink.append(new_run)
|
||||
p._p.append(hyperlink)
|
||||
|
||||
# Add relationship to the part
|
||||
doc.part.rels.add_relationship(
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
|
||||
"https://dify.ai",
|
||||
r_id,
|
||||
is_external=True,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
|
||||
docs = extractor.extract()
|
||||
# Verify modern hyperlink extraction
|
||||
assert "Visit[Dify](https://dify.ai)" in docs[0].page_content
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_extract_legacy_hyperlinks(monkeypatch):
|
||||
# Mock db and storage
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
doc = Document()
|
||||
p = doc.add_paragraph()
|
||||
|
||||
# Construct a legacy HYPERLINK field:
|
||||
# 1. w:fldChar (begin)
|
||||
# 2. w:instrText (HYPERLINK "http://example.com")
|
||||
# 3. w:fldChar (separate)
|
||||
# 4. w:r (visible text "Example")
|
||||
# 5. w:fldChar (end)
|
||||
|
||||
run1 = OxmlElement("w:r")
|
||||
fldCharBegin = OxmlElement("w:fldChar")
|
||||
fldCharBegin.set(qn("w:fldCharType"), "begin")
|
||||
run1.append(fldCharBegin)
|
||||
p._p.append(run1)
|
||||
|
||||
run2 = OxmlElement("w:r")
|
||||
instrText = OxmlElement("w:instrText")
|
||||
instrText.text = ' HYPERLINK "http://example.com" '
|
||||
run2.append(instrText)
|
||||
p._p.append(run2)
|
||||
|
||||
run3 = OxmlElement("w:r")
|
||||
fldCharSep = OxmlElement("w:fldChar")
|
||||
fldCharSep.set(qn("w:fldCharType"), "separate")
|
||||
run3.append(fldCharSep)
|
||||
p._p.append(run3)
|
||||
|
||||
run4 = OxmlElement("w:r")
|
||||
t4 = OxmlElement("w:t")
|
||||
t4.text = "Example"
|
||||
run4.append(t4)
|
||||
p._p.append(run4)
|
||||
|
||||
run5 = OxmlElement("w:r")
|
||||
fldCharEnd = OxmlElement("w:fldChar")
|
||||
fldCharEnd.set(qn("w:fldCharType"), "end")
|
||||
run5.append(fldCharEnd)
|
||||
p._p.append(run5)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
|
||||
docs = extractor.extract()
|
||||
# Verify legacy hyperlink extraction
|
||||
assert "[Example](http://example.com)" in docs[0].page_content
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
@ -0,0 +1,113 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class TestRetrievalService:
|
||||
@pytest.fixture
|
||||
def mock_dataset(self) -> Dataset:
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = str(uuid4())
|
||||
dataset.tenant_id = str(uuid4())
|
||||
dataset.name = "test_dataset"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.provider = "dify"
|
||||
return dataset
|
||||
|
||||
def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
|
||||
"""
|
||||
Repro test for current bug:
|
||||
reranking runs after `with flask_app.app_context():` exits.
|
||||
`_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
|
||||
so we must assert from that list (not from an outer try/except).
|
||||
"""
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
flask_app = Flask(__name__)
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# second dataset to ensure dataset_count > 1 reranking branch
|
||||
secondary_dataset = Mock(spec=Dataset)
|
||||
secondary_dataset.id = str(uuid4())
|
||||
secondary_dataset.provider = "dify"
|
||||
secondary_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# retriever returns 1 doc into internal list (all_documents_item)
|
||||
document = Document(
|
||||
page_content="Context aware doc",
|
||||
metadata={
|
||||
"doc_id": "doc1",
|
||||
"score": 0.95,
|
||||
"document_id": str(uuid4()),
|
||||
"dataset_id": mock_dataset.id,
|
||||
},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
def fake_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.append(document)
|
||||
|
||||
called = {"init": 0, "invoke": 0}
|
||||
|
||||
class ContextRequiredPostProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
called["init"] += 1
|
||||
# will raise RuntimeError if no Flask app context exists
|
||||
_ = current_app.name
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
called["invoke"] += 1
|
||||
_ = current_app.name
|
||||
return kwargs.get("documents") or args[1]
|
||||
|
||||
# output list from _multiple_retrieve_thread
|
||||
all_documents: list[Document] = []
|
||||
|
||||
# IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
|
||||
thread_exceptions: list[Exception] = []
|
||||
|
||||
def target():
|
||||
with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
|
||||
with patch(
|
||||
"core.rag.retrieval.dataset_retrieval.DataPostProcessor",
|
||||
ContextRequiredPostProcessor,
|
||||
):
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=flask_app,
|
||||
available_datasets=[mock_dataset, secondary_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={
|
||||
"reranking_provider_name": "cohere",
|
||||
"reranking_model_name": "rerank-v2",
|
||||
},
|
||||
weights=None,
|
||||
top_k=3,
|
||||
score_threshold=0.0,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=2, # force reranking branch
|
||||
thread_exceptions=thread_exceptions, # ✅ key
|
||||
)
|
||||
|
||||
t = threading.Thread(target=target)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
# Ensure reranking branch was actually executed
|
||||
assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
|
||||
|
||||
# Current buggy code should record an exception (not raise it)
|
||||
assert not thread_exceptions, thread_exceptions
|
||||
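# For context: the fix this repro points at is keeping the reranking/post-processing step inside
# the `with flask_app.app_context():` block, so `current_app` is still available when
# DataPostProcessor is constructed and invoked. Sketch of the intended shape (illustrative only,
# argument lists elided because they are not shown in this diff):
#
#     with flask_app.app_context():
#         # ... run per-dataset retrieval and collect all_documents_item ...
#         if reranking_enable and dataset_count > 1:
#             post_processor = DataPostProcessor(...)  # constructed while the context is active
#             all_documents.extend(post_processor.invoke(...))
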
@ -3,8 +3,15 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.variables import IntegerVariable, StringVariable
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
CommandType,
|
||||
GraphEngineCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestRedisChannel:
|
||||
@ -148,6 +155,43 @@ class TestRedisChannel:
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert isinstance(commands[1], AbortCommand)
|
||||
|
||||
def test_fetch_commands_with_update_variables_command(self):
|
||||
"""Test fetching update variables command from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
update_command = UpdateVariablesCommand(
|
||||
updates=[
|
||||
VariableUpdate(
|
||||
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
|
||||
),
|
||||
VariableUpdate(
|
||||
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
command_json = json.dumps(update_command.model_dump())
|
||||
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], UpdateVariablesCommand)
|
||||
assert isinstance(commands[0].updates[0].value, StringVariable)
|
||||
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
|
||||
assert commands[0].updates[0].value.value == "bar"
|
||||
|
||||
def test_fetch_commands_skips_invalid_json(self):
|
||||
"""Test that invalid JSON commands are skipped."""
|
||||
mock_redis = MagicMock()
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.layers.base import (
|
||||
GraphEngineLayer,
|
||||
GraphEngineLayerNotInitializedError,
|
||||
)
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..test_table_runner import WorkflowRunner
|
||||
|
||||
|
||||
class LayerForTest(GraphEngineLayer):
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_layer_runtime_state_raises_when_uninitialized() -> None:
|
||||
layer = LayerForTest()
|
||||
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
|
||||
|
||||
def test_layer_runtime_state_available_after_engine_layer() -> None:
|
||||
runner = WorkflowRunner()
|
||||
fixture_data = runner.load_fixture("simple_passthrough_workflow")
|
||||
graph, graph_runtime_state = runner.create_graph_from_fixture(
|
||||
fixture_data,
|
||||
inputs={"query": "test layer state"},
|
||||
)
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
layer = LayerForTest()
|
||||
engine.layer(layer)
|
||||
|
||||
outputs = layer.graph_runtime_state.outputs
|
||||
ready_queue_size = layer.graph_runtime_state.ready_queue_size
|
||||
|
||||
assert outputs == {}
|
||||
assert isinstance(ready_queue_size, int)
|
||||
assert ready_queue_size >= 0
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=execution_coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
dispatcher._dispatcher_loop()
|
||||
assert event_queue.empty()
|
||||
@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int:
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@ -4,12 +4,19 @@ import time
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import IntegerVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -180,3 +187,67 @@ def test_pause_command():

graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]


def test_update_variables_command_updates_pool():
"""Test that GraphEngine updates variable pool via update variables command."""

shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")

mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"

start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=shared_runtime_state,
)
mock_graph.nodes["start"] = start_node

mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

command_channel = InMemoryChannel()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=shared_runtime_state,
command_channel=command_channel,
)

update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
),
]
)
command_channel.send_command(update_command)

list(engine.run())

updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])

assert updated_existing is not None
assert updated_existing.value == "new value"
assert added_new is not None
assert added_new.value == 123

@ -0,0 +1,539 @@
"""
Unit tests for stop_event functionality in GraphEngine.

Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""

import threading
import time
from unittest.mock import MagicMock, Mock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom


class TestStopEventPropagation:
"""Test suite for stop_event propagation through GraphEngine components."""

def test_graph_engine_creates_stop_event(self):
"""Test that GraphEngine creates a stop_event on initialization."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Verify stop_event was created
assert engine._stop_event is not None
assert isinstance(engine._stop_event, threading.Event)

# Verify it was set in graph_runtime_state
assert runtime_state.stop_event is not None
assert runtime_state.stop_event is engine._stop_event

def test_stop_event_cleared_on_start(self):
"""Test that stop_event is cleared when execution starts."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id

start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Set the stop_event before running
engine._stop_event.set()
assert engine._stop_event.is_set()

# Run the engine (should clear the stop_event)
events = list(engine.run())

# After running, stop_event should be set again (by _stop_execution)
# But during start it was cleared
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)

def test_stop_event_set_on_stop(self):
"""Test that stop_event is set when execution stops."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id

start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Initially not set
assert not engine._stop_event.is_set()

# Run the engine
list(engine.run())

# After execution completes, stop_event should be set
assert engine._stop_event.is_set()

def test_stop_event_passed_to_worker_pool(self):
"""Test that stop_event is passed to WorkerPool."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Verify WorkerPool has the stop_event
assert engine._worker_pool._stop_event is not None
assert engine._worker_pool._stop_event is engine._stop_event

def test_stop_event_passed_to_dispatcher(self):
"""Test that stop_event is passed to Dispatcher."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Verify Dispatcher has the stop_event
assert engine._dispatcher._stop_event is not None
assert engine._dispatcher._stop_event is engine._stop_event


class TestNodeStopCheck:
"""Test suite for Node._should_stop() functionality."""

def test_node_should_stop_checks_runtime_state(self):
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)

# Initially stop_event is not set
assert not answer_node._should_stop()

# Set the stop_event
runtime_state.stop_event.set()

# Now _should_stop should return True
assert answer_node._should_stop()

def test_node_run_checks_stop_event_between_yields(self):
"""Test that Node.run() checks stop_event between yielding events."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)

# Set stop_event BEFORE running the node
runtime_state.stop_event.set()

# Run the node - should yield start event then detect stop
# The node should check stop_event before processing
assert answer_node._should_stop(), "stop_event should be set"

# Run and collect events
events = list(answer_node.run())

# Since stop_event is set at the start, we should get:
# 1. NodeRunStartedEvent (always yielded first)
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
assert len(events) >= 2
assert isinstance(events[0], NodeRunStartedEvent)

# Note: AnswerNode is very simple and might complete before stop check
# The important thing is that _should_stop() returns True when stop_event is set
assert answer_node._should_stop()


class TestStopEventIntegration:
"""Integration tests for stop_event in workflow execution."""

def test_simple_workflow_respects_stop_event(self):
"""Test that a simple workflow respects stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"

# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)

answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)

mock_graph.nodes["start"] = start_node
mock_graph.nodes["answer"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Set stop_event before running
runtime_state.stop_event.set()

# Run the engine
events = list(engine.run())

# Should get started event but not succeeded (due to stop)
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
# The workflow should still complete (start node runs quickly)
# but answer node might be cancelled depending on timing

def test_stop_event_with_concurrent_nodes(self):
"""Test stop_event behavior with multiple concurrent nodes."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

# Create multiple nodes
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes[f"answer_{i}"] = answer_node

mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# All nodes should share the same stop_event
for node in mock_graph.nodes.values():
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
assert node.graph_runtime_state.stop_event is engine._stop_event


class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""

@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

dispatcher = engine._dispatcher
dispatcher.start() # This will create and start the mocked thread

mock_thread_instance = mock_thread_cls.return_value
mock_thread_instance.is_alive.return_value = True

dispatcher.stop()

mock_thread_instance.join.assert_called_once_with(timeout=2.0)

@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

worker_pool = engine._worker_pool
worker_pool.start(initial_count=1) # Start with one worker

mock_worker_instance = mock_worker_cls.return_value
mock_worker_instance.is_alive.return_value = True

worker_pool.stop()

mock_worker_instance.join.assert_called_once_with(timeout=2.0)


class TestStopEventResumeBehavior:
"""Test stop_event behavior during workflow resume."""

def test_stop_event_cleared_on_resume(self):
"""Test that stop_event is cleared when resuming a paused workflow."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id

start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Simulate a previous execution that set stop_event
engine._stop_event.set()
assert engine._stop_event.is_set()

# Run the engine (should clear stop_event in _start_execution)
events = list(engine.run())

# Execution should complete successfully
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)


class TestWorkerStopBehavior:
"""Test Worker behavior with shared stop_event."""

def test_worker_uses_shared_stop_event(self):
"""Test that Worker uses shared stop_event from GraphEngine."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()

engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
)

# Get the worker pool and check workers
worker_pool = engine._worker_pool

# Start the worker pool to create workers
worker_pool.start()

# Check that at least one worker was created
assert len(worker_pool._workers) > 0

# Verify workers use the shared stop_event
for worker in worker_pool._workers:
assert worker._stop_event is engine._stop_event

# Clean up
worker_pool.stop()

def test_worker_stop_is_noop(self):
"""Test that Worker.stop() is now a no-op."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

# Create a mock worker
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_engine.worker import Worker

ready_queue = InMemoryReadyQueue()
event_queue = MagicMock()

# Create a proper mock graph with real dict
mock_graph = Mock(spec=Graph)
mock_graph.nodes = {} # Use real dict

stop_event = threading.Event()

worker = Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=mock_graph,
layers=[],
stop_event=stop_event,
)

# Calling stop() should do nothing (no-op)
# and should NOT set the stop_event
worker.stop()
assert not stop_event.is_set()
@ -78,7 +78,7 @@ class TestFileSaverImpl:
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)

def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"

@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType

@ -127,7 +127,9 @@ class TestTemplateTransformNode:
"""Test version class method."""
assert TemplateTransformNode.version() == "1"

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))

# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
mock_execute.return_value = "Hello Alice, you are 30 years old!"

node = TemplateTransformNode(
id="test_node",
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
}

mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
mock_execute.return_value = "Value: "

node = TemplateTransformNode(
id="test_node",
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
mock_execute.side_effect = TemplateRenderError("Template syntax error")

node = TemplateTransformNode(
id="test_node",
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
mock_execute.return_value = "This is a very long output that exceeds the limit"

node = TemplateTransformNode(
id="test_node",
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
mock_execute.return_value = "apple, banana, orange (Total: 3)"

node = TemplateTransformNode(
id="test_node",
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
"template": "This is a static message.",
}

mock_execute.return_value = {"result": "This is a static message."}
mock_execute.return_value = "This is a static message."

node = TemplateTransformNode(
id="test_node",
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
mock_execute.return_value = "Total: $31.5"

node = TemplateTransformNode(
id="test_node",
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}

mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
mock_execute.return_value = "Name: John Doe, Email: john@example.com"

node = TemplateTransformNode(
id="test_node",
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
mock_tags.to_object.return_value = ["python", "ai", "workflow"]

mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
mock_execute.return_value = "Tags: #python #ai #workflow "

node = TemplateTransformNode(
id="test_node",

@ -1,14 +1,14 @@
import time
import uuid
from unittest import mock
from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -86,9 +86,6 @@ def test_overwrite_string_variable():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
@ -104,20 +101,14 @@ def test_overwrite_string_variable():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == input_variable.value

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -191,9 +182,6 @@ def test_append_variable_to_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
@ -209,22 +197,14 @@ def test_append_variable_to_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == ["the first value", "the second value"]

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -287,9 +267,6 @@ def test_clear_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

node_config = {
"id": "node_id",
"data": {
@ -305,20 +282,14 @@ def test_clear_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == []

got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None

@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []


def test_node_factory_creates_variable_assigner_node():
graph_config = {
"edges": [],
"nodes": [
{
"data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
"id": "assigner",
},
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)

node = node_factory.create_node(graph_config["nodes"][0])

assert isinstance(node, VariableAssignerNode)

@ -1,6 +1,6 @@
import pytest

from libs.helper import extract_tenant_id
from libs.helper import escape_like_pattern, extract_tenant_id
from models.account import Account
from models.model import EndUser

@ -63,3 +63,51 @@ class TestExtractTenantId:

with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)


class TestEscapeLikePattern:
"""Test cases for the escape_like_pattern utility function."""

def test_escape_percent_character(self):
"""Test escaping percent character."""
result = escape_like_pattern("50% discount")
assert result == "50\\% discount"

def test_escape_underscore_character(self):
"""Test escaping underscore character."""
result = escape_like_pattern("test_data")
assert result == "test\\_data"

def test_escape_backslash_character(self):
"""Test escaping backslash character."""
result = escape_like_pattern("path\\to\\file")
assert result == "path\\\\to\\\\file"

def test_escape_combined_special_characters(self):
"""Test escaping multiple special characters together."""
result = escape_like_pattern("file_50%\\path")
assert result == "file\\_50\\%\\\\path"

def test_escape_empty_string(self):
"""Test escaping empty string returns empty string."""
result = escape_like_pattern("")
assert result == ""

def test_escape_none_handling(self):
"""Test escaping None returns None (falsy check handles it)."""
# The function checks `if not pattern`, so None is falsy and returns as-is
result = escape_like_pattern(None)
assert result is None

def test_escape_normal_string_no_change(self):
"""Test that normal strings without special characters are unchanged."""
result = escape_like_pattern("normal text")
assert result == "normal text"

def test_escape_order_matters(self):
"""Test that backslash is escaped first to prevent double escaping."""
# If we escape % first, then escape \, we might get wrong results
# This test ensures the order is correct: \ first, then % and _
result = escape_like_pattern("test\\%_value")
# Should be: test\\\%\_value
assert result == "test\\\\\\%\\_value"

Some files were not shown because too many files have changed in this diff.