Mirror of https://github.com/langgenius/dify.git (synced 2026-01-14 06:07:33 +08:00)

Commit 465135838e: Merge branch 'main' into refactor/query-params-nuqs
@@ -3,6 +3,7 @@
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
"pyright-lsp@claude-plugins-official": true,
"ralph-wiggum@claude-plugins-official": true
}
}
73  .claude/skills/frontend-code-review/SKILL.md  Normal file
@@ -0,0 +1,73 @@
---
name: frontend-code-review
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
---

# Frontend Code Review

## Intent

Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:

1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.

Stick to the checklist below for every applicable file and mode.
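
For the pending-change mode, the file list typically comes straight from git. A minimal sketch, illustrative only and not part of the skill definition, assuming a Node environment with git on the PATH:

```ts
import { execSync } from 'node:child_process'

const frontendExtensions = ['.tsx', '.ts', '.js']

// Collect the staged files a pending-change review should inspect.
export function listStagedFrontendFiles(): string[] {
  const output = execSync('git diff --name-only --cached', { encoding: 'utf-8' })
  return output
    .split('\n')
    .filter(path => frontendExtensions.some(ext => path.endsWith(ext)))
}
```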

## Checklist

See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), and [references/business-logic.md](references/business-logic.md) for the living checklist, split by category—treat it as the canonical set of rules to follow.

Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
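
One way to carry that urgency metadata is a small record per finding. A hypothetical shape (field names are illustrative, not defined by this skill):

```ts
type ReviewFinding = {
  rule: string          // e.g. 'Complex prop memoization'
  category: 'Code Quality' | 'Performance' | 'Business Logic'
  isUrgent: boolean     // mirrors the IsUrgent flag in the rule catalog
  filePath: string
  line: number
  snippet: string       // relevant code snippet or pointer
  suggestedFix: string
}
```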

## Review Process

1. Open the relevant component/module. Gather the lines that relate to class names, React Flow hooks, prop memoization, and styling.
2. For each rule in the checklist, note where the code deviates and capture a representative snippet.
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
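
A sketch of the grouping in step 3, reusing the hypothetical `ReviewFinding` shape above (ordering logic only, not prescribed by the skill):

```ts
const categoryOrder = ['Code Quality', 'Performance', 'Business Logic']

// Urgent findings first, then catalog category order within each group.
function sortFindings(findings: ReviewFinding[]): ReviewFinding[] {
  return [...findings].sort((a, b) =>
    Number(b.isUrgent) - Number(a.isUrgent)
    || categoryOrder.indexOf(a.category) - categoryOrder.indexOf(b.category),
  )
}
```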

## Required output

When invoked, the response must exactly follow one of the two templates:

### Template A (any findings)

```
# Code review

Found <N> urgent issues that need to be fixed:

## 1 <brief description of bug>

FilePath: <path> line <line>

<relevant code snippet or pointer>

### Suggested fix

<brief description of suggested fix>

---

... (repeat for each urgent issue) ...

Found <M> suggestions for improvement:

## 1 <brief description of suggestion>

FilePath: <path> line <line>

<relevant code snippet or pointer>

### Suggested fix

<brief description of suggested fix>

---

... (repeat for each suggestion) ...
```

If there are no urgent issues, omit that section. If there are no suggestions, omit that section.

If there are more than 10 issues, summarize the count as "10+ urgent issues" or "10+ suggestions" and output only the first 10 issues.

Don't compress the blank lines between sections; keep them as-is for readability.

If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"

### Template B (no issues)

```
## Code review

No issues found.
```
@@ -0,0 +1,15 @@
# Rule Catalog — Business Logic

## Can't use workflowStore in Node components

IsUrgent: True

### Description

File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`

Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this.

### Suggested Fix

Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
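
A minimal sketch of the safe pattern (the component and its data shape are illustrative, not copied from the repository):

```tsx
import { useNodes } from 'reactflow'

// Reads nodes from React Flow's own store, which exists in every render
// context, rather than from the app-level workflowStore Provider.
const ExampleNode = ({ id }: { id: string }) => {
  const nodes = useNodes()
  const current = nodes.find(node => node.id === id)
  return <div>{current?.data?.title ?? 'Untitled node'}</div>
}

export default ExampleNode
```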
@@ -0,0 +1,44 @@
# Rule Catalog — Code Quality

## Conditional class names use utility function

IsUrgent: True
Category: Code Quality

### Description

Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.

### Suggested Fix

```ts
import { cn } from '@/utils/classnames'
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
```
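
For contrast, a hedged before/after sketch (component and prop names are made up; it assumes the utility ignores falsy arguments, as classnames-style helpers typically do):

```tsx
import { cn } from '@/utils/classnames'

type Props = { isActive: boolean; disabled: boolean }

// Wrong: class string assembled by hand with a template literal
const BadgeBefore = ({ isActive, disabled }: Props) => (
  <span className={`badge ${isActive ? 'text-primary-600' : 'text-gray-500'}${disabled ? ' opacity-50' : ''}`} />
)

// Right: every condition expressed through the shared utility
const BadgeAfter = ({ isActive, disabled }: Props) => (
  <span className={cn('badge', isActive ? 'text-primary-600' : 'text-gray-500', disabled && 'opacity-50')} />
)
```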

## Tailwind-first styling

IsUrgent: True
Category: Code Quality

### Description

Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
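
A small before/after sketch (class and component names are hypothetical):

```tsx
import type { ReactNode } from 'react'

// Before: a one-off CSS module
//   card.module.css → .card { display: flex; gap: 8px; border-radius: 8px; }
//   <div className={styles.card}>…</div>

// After: the same styling expressed directly with Tailwind utilities
const Card = ({ children }: { children: ReactNode }) => (
  <div className='flex gap-2 rounded-lg'>{children}</div>
)

export default Card
```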

Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.

## Classname ordering for easy overrides

### Description

When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.

Example:

```tsx
import { cn } from '@/utils/classnames'

const Button = ({ className }: { className?: string }) => {
  return <div className={cn('bg-primary-600', className)}></div>
}
```
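
A usage sketch of why the ordering matters (the caller's classes are hypothetical):

```tsx
// Because `className` is applied last, the caller can extend the defaults:
// here extra margin and a shadow are appended after the component's own classes.
const Toolbar = () => <Button className='mt-2 shadow-sm' />
```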
@@ -0,0 +1,45 @@
# Rule Catalog — Performance

## React Flow data usage

IsUrgent: True
Category: Performance

### Description

When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
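
A minimal sketch of the intended split between render-time subscription and callback-time reads (handler logic is illustrative):

```tsx
import { useCallback } from 'react'
import { useNodes, useStoreApi } from 'reactflow'

const NodeCounter = () => {
  // Render path: subscribe via the hook so the component re-renders with the graph.
  const nodes = useNodes()
  const store = useStoreApi()

  // Event path: read the latest state inside the callback instead of
  // capturing a possibly stale `nodes` array.
  const logSelection = useCallback(() => {
    const { getNodes } = store.getState()
    console.log(getNodes().filter(node => node.selected).length)
  }, [store])

  return <button onClick={logSelection}>{nodes.length} nodes</button>
}

export default NodeCounter
```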

## Complex prop memoization

IsUrgent: True
Category: Performance

### Description

Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.

Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.

Wrong:

```tsx
<HeavyComp
  config={{
    provider: ...,
    detail: ...
  }}
/>
```

Right:

```tsx
const config = useMemo(() => ({
  provider: ...,
  detail: ...
}), [provider, detail]);

<HeavyComp
  config={config}
/>
```
@@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'

// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
// No explicit mock needed for most tests
//
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
//   useTranslation: () => ({
//     t: (key: string) => {
//       const customTranslations: Record<string, string> = {
//         'my.custom.key': 'Custom Translation',
//       }
//       return customTranslations[key] || key
//     },
//   }),
// import { createReactI18nextMock } from '@/test/i18n-mock'
// vi.mock('react-i18next', () => createReactI18nextMock({
//   'my.custom.key': 'Custom Translation',
//   'button.save': 'Save',
// }))

// Router (if component uses useRouter, usePathname, useSearchParams)

@@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global

### 1. i18n (Auto-loaded via Global Mock)

A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.

For tests requiring custom translations, override the mock:
The global mock provides:

- `useTranslation` - returns translation keys with namespace prefix
- `Trans` component - renders i18nKey and components
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`

**Default behavior**: Most tests should use the global mock (no local override needed).

**For custom translations**: Use the helper function from `@/test/i18n-mock`:

```typescript
vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => {
      const translations: Record<string, string> = {
        'my.custom.key': 'Custom translation',
      }
      return translations[key] || key
    },
  }),
import { createReactI18nextMock } from '@/test/i18n-mock'

vi.mock('react-i18next', () => createReactI18nextMock({
  'my.custom.key': 'Custom translation',
  'button.save': 'Save',
}))
```

**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
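
A sketch of a test that leans on the helper (the component under test is assumed, not taken from the codebase, and `toBeInTheDocument` presumes the usual jest-dom matchers are registered in the Vitest setup):

```tsx
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { createReactI18nextMock } from '@/test/i18n-mock'
import SaveButton from './save-button' // hypothetical component that renders t('button.save')

vi.mock('react-i18next', () => createReactI18nextMock({
  'button.save': 'Save',
}))

describe('SaveButton', () => {
  it('renders the custom translation', () => {
    render(<SaveButton />)
    expect(screen.getByText('Save')).toBeInTheDocument()
  })
})
```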

### 2. Next.js Router

```typescript
2  .github/pull_request_template.md  vendored
@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
10  .github/workflows/style.yml  vendored
@@ -110,6 +110,16 @@ jobs:
        working-directory: ./web
        run: pnpm run type-check:tsgo

      - name: Web dead code check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run knip

      - name: Web build check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run build

  superlinter:
    name: SuperLinter
    runs-on: ubuntu-latest
@@ -5,6 +5,7 @@ on:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:

permissions:
contents: write
@@ -18,7 +19,8 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v6
# Keep using the old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@@ -26,21 +28,28 @@ jobs:
- name: Check for file changes in i18n/en-US
id: check_files
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi

- name: Install pnpm
1  .gitignore  vendored
@@ -235,3 +235,4 @@ scripts/stress-test/reports/

# settings
*.local.json
*.local.md
5  Makefile
@@ -60,9 +60,10 @@ check:
	@echo "✅ Code check complete"

lint:
	@echo "🔧 Running ruff format, check with fixes, and import linter..."
	@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
	@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
	@uv run --directory api --dev lint-imports
	@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
	@echo "✅ Linting complete"

type-check:
@@ -122,7 +123,7 @@ help:
	@echo "Backend Code Quality:"
	@echo " make format - Format code with ruff"
	@echo " make check - Check code with ruff"
	@echo " make lint - Format and fix code with ruff"
	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
	@echo " make type-check - Run type checking with basedpyright"
	@echo " make test - Run backend unit tests"
	@echo ""
@@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region

# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto

# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
@@ -493,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
@@ -3,9 +3,11 @@ root_packages =
    core
    configs
    controllers
    extensions
    models
    tasks
    services
include_external_packages = True

[importlinter:contract:workflow]
name = Workflow
@@ -33,6 +35,29 @@ ignore_imports =
    core.workflow.nodes.loop.loop_node -> core.workflow.graph
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels

[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
    core.workflow
forbidden_modules =
    extensions.ext_database
    extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
    core.workflow.nodes.agent.agent_node -> extensions.ext_database
    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
    core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
    core.workflow.nodes.llm.file_saver -> extensions.ext_database
    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
    core.workflow.nodes.llm.node -> extensions.ext_database
    core.workflow.nodes.tool.tool_node -> extensions.ext_database
    core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
    core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

[importlinter:contract:rsc]
name = RSC
type = layers
@@ -2,9 +2,11 @@ import logging
import time

from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp

logger = logging.getLogger(__name__)
@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
    # add before request hook
    @dify_app.before_request
    def before_request():
        # add a unique identifier to each request
        # Initialize logging context for this request
        init_request_context()
        RecyclableContextVar.increment_thread_recycles()

    # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
    # add after request hook for injecting trace headers from OpenTelemetry span context
    # Only adds headers when OTEL is enabled and has valid context
    @dify_app.after_request
    def add_trace_id_header(response):
    def add_trace_headers(response):
        try:
            span = get_current_span()
            ctx = span.get_span_context() if span else None
            if ctx and ctx.is_valid:
                trace_id_hex = format(ctx.trace_id, "032x")
                # Avoid duplicates if some middleware added it
                if "X-Trace-Id" not in response.headers:
                    response.headers["X-Trace-Id"] = trace_id_hex

            if not ctx or not ctx.is_valid:
                return response

            # Inject trace headers from OTEL context
            if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
                response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
            if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
                response.headers["X-Span-Id"] = format(ctx.span_id, "016x")

        except Exception:
            # Never break the response due to tracing header injection
            logger.warning("Failed to add trace ID to response header", exc_info=True)
            logger.warning("Failed to add trace headers to response", exc_info=True)
        return response

    # Capture the decorator's return value to avoid pyright reportUnusedFunction
    _ = before_request
    _ = add_trace_id_header
    _ = add_trace_headers

    return dify_app
211  api/commands.py
@@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
|
||||
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
||||
|
||||
|
||||
@click.command("file-usage", help="Query file usages and show where files are referenced.")
|
||||
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
|
||||
@click.option("--key", type=str, default=None, help="Filter by storage key.")
|
||||
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
|
||||
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
|
||||
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
|
||||
def file_usage(
|
||||
file_id: str | None,
|
||||
key: str | None,
|
||||
src: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
output_json: bool,
|
||||
):
|
||||
"""
|
||||
Query file usages and show where files are referenced in the database.
|
||||
|
||||
This command reuses the same reference checking logic as clear-orphaned-file-records
|
||||
and displays detailed information about where each file is referenced.
|
||||
"""
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "id_column": "id", "key_column": "key"},
|
||||
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
|
||||
]
|
||||
ids_tables = [
|
||||
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
|
||||
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
|
||||
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
|
||||
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
|
||||
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
|
||||
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
|
||||
]
|
||||
|
||||
# Stream file usages with pagination to avoid holding all results in memory
|
||||
paginated_usages = []
|
||||
total_count = 0
|
||||
|
||||
# First, build a mapping of file_id -> storage_key from the base tables
|
||||
file_key_map = {}
|
||||
for files_table in files_tables:
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
|
||||
|
||||
# If filtering by key or file_id, verify it exists
|
||||
if file_id and file_id not in file_key_map:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
if key:
|
||||
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
|
||||
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
|
||||
if not matching_file_ids:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# For each reference table/column, find matching file IDs and record the references
|
||||
for ids_table in ids_tables:
|
||||
src_filter = f"{ids_table['table']}.{ids_table['column']}"
|
||||
|
||||
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
|
||||
if src:
|
||||
if "%" in src or "_" in src:
|
||||
import fnmatch
|
||||
|
||||
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
|
||||
pattern = src.replace("%", "*").replace("_", "?")
|
||||
if not fnmatch.fnmatch(src_filter, pattern):
|
||||
continue
|
||||
else:
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
"total": total_count,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"usages": paginated_usages,
|
||||
}
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
click.echo(
|
||||
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
|
||||
)
|
||||
click.echo("")
|
||||
|
||||
if not paginated_usages:
|
||||
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
|
||||
return
|
||||
|
||||
# Print table header
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
click.echo(click.style("-" * 190, fg="white"))
|
||||
|
||||
# Print each usage
|
||||
for usage in paginated_usages:
|
||||
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
|
||||
|
||||
# Show pagination info
|
||||
if offset + limit < total_count:
|
||||
click.echo("")
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
|
||||
)
|
||||
)
|
||||
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
|
||||
|
||||
|
||||
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from configs.extra.archive_config import ArchiveStorageConfig
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
|
||||
class ExtraServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
ArchiveStorageConfig,
|
||||
NotionConfig,
|
||||
SentryConfig,
|
||||
):
|
||||
|
||||
43
api/configs/extra/archive_config.py
Normal file
43
api/configs/extra/archive_config.py
Normal file
@ -0,0 +1,43 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ArchiveStorageConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for workflow run logs archiving storage.
|
||||
"""
|
||||
|
||||
ARCHIVE_STORAGE_ENABLED: bool = Field(
|
||||
description="Enable workflow run logs archiving to S3-compatible storage",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
|
||||
description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
|
||||
description="Name of the bucket to store archived workflow logs",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
|
||||
description="Name of the bucket to store exported workflow runs",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
|
||||
description="Access key ID for authenticating with storage",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
|
||||
description="Secret access key for authenticating with storage",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ARCHIVE_STORAGE_REGION: str = Field(
|
||||
description="Region for storage (use 'auto' if the provider supports it)",
|
||||
default="auto",
|
||||
)
|
||||
@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
|
||||
default="INFO",
|
||||
)
|
||||
|
||||
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
|
||||
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
LOG_FILE: str | None = Field(
|
||||
description="File path for log output.",
|
||||
default=None,
|
||||
|
||||
@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
|
||||
description="Authentication token for Milvus, if token-based authentication is enabled",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_USER: str | None = Field(
|
||||
description="Username for authenticating with Milvus, if username/password authentication is enabled",
|
||||
default=None,
|
||||
|
||||
@ -1,62 +1,59 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AppIconUrlField
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
def build_system_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the system parameters model for the API or Namespace."""
|
||||
return api_or_ns.model("SystemParameters", parameters__system_parameters)
|
||||
class SystemParameters(BaseModel):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: JSONObject
|
||||
speech_to_text: JSONObject
|
||||
text_to_speech: JSONObject
|
||||
retriever_resource: JSONObject
|
||||
annotation_reply: JSONObject
|
||||
more_like_this: JSONObject
|
||||
user_input_form: list[JSONObject]
|
||||
sensitive_word_avoidance: JSONObject
|
||||
file_upload: JSONObject
|
||||
system_parameters: SystemParameters
|
||||
|
||||
|
||||
def build_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the parameters model for the API or Namespace."""
|
||||
copied_fields = parameters_fields.copy()
|
||||
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
|
||||
return api_or_ns.model("Parameters", copied_fields)
|
||||
class Site(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str
|
||||
chat_color_theme: str | None = None
|
||||
chat_color_theme_inverted: bool
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
default_language: str
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
site_fields = {
|
||||
"title": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"default_language": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
def build_site_model(api_or_ns: Api | Namespace):
|
||||
"""Build the site model for the API or Namespace."""
|
||||
return api_or_ns.model("Site", site_fields)
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
if self.icon and self.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
return None
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
# XSS prevention: patterns that could lead to XSS attacks
|
||||
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
|
||||
_XSS_PATTERNS = [
|
||||
r"<script[^>]*>.*?</script>", # Script tags
|
||||
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
|
||||
r"javascript:", # JavaScript protocol
|
||||
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
|
||||
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
|
||||
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
|
||||
r"<embed[^>]*>", # Embed tags (self-closing)
|
||||
r"<link[^>]*>", # Link tags with javascript
|
||||
]
|
||||
|
||||
|
||||
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
|
||||
"""
|
||||
Validate that a string value doesn't contain potential XSS payloads.
|
||||
|
||||
Args:
|
||||
value: The string value to validate
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
The original value if safe
|
||||
|
||||
Raises:
|
||||
ValueError: If the value contains XSS patterns
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
value_lower = value.lower()
|
||||
for pattern in _XSS_PATTERNS:
|
||||
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
|
||||
raise ValueError(
|
||||
f"{field_name} contains invalid characters or patterns. "
|
||||
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
|
||||
@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
|
||||
@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
|
||||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=args.pinned,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return ConversationService.rename(
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
@ -2,8 +2,7 @@ import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
|
||||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
|
||||
args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.conversation_id),
|
||||
str(args.first_id) if args.first_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from flask_restx import marshal_with
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
@ -13,7 +11,6 @@ from services.app_service import AppService
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel):
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
|
||||
args = SavedMessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(
|
||||
pagination = SavedMessageService.pagination_by_last_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.last_id) if args.last_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
@ -4,12 +4,11 @@ from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_at_least_one_field(self):
|
||||
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
|
||||
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
|
||||
return self
|
||||
|
||||
|
||||
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||
"""Request payload for verifying subscription credentials."""
|
||||
@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=user.current_tenant_id,
|
||||
@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
|
||||
try:
|
||||
# rename only
|
||||
if (
|
||||
args.name is not None
|
||||
and args.credentials is None
|
||||
and args.parameters is None
|
||||
and args.properties is None
|
||||
):
|
||||
# For rename only, just update the name
|
||||
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
|
||||
# When the credential type is UNAUTHORIZED, it indicates the subscription was manually created.
|
||||
# Manually created subscriptions don't have credentials or parameters;
|
||||
# they only have a name and properties (which are input by the user).
|
||||
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
|
||||
if rename or manually_created:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
name=request.name,
|
||||
properties=request.properties,
|
||||
)
|
||||
return 200
|
||||
|
||||
# rebuild for create automatically by the provider
|
||||
match subscription.credential_type:
|
||||
case CredentialType.UNAUTHORIZED:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
properties=args.properties,
|
||||
)
|
||||
return 200
|
||||
case CredentialType.API_KEY | CredentialType.OAUTH2:
|
||||
if args.credentials:
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in args.credentials.items()
|
||||
}
|
||||
else:
|
||||
new_credentials = subscription.credentials
|
||||
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=args.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=new_credentials,
|
||||
parameters=args.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
case _:
|
||||
raise BadRequest("Invalid credential type")
|
||||
# For the remaining cases (API_KEY, OAUTH2),
|
||||
# we need to call the third-party provider (e.g. GitHub) to rebuild the subscription.
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=request.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=request.credentials or subscription.credentials,
|
||||
parameters=request.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -92,7 +92,7 @@ annotation_list_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Api | Namespace):
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import build_parameters_model
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters.
|
||||
|
||||
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/meta")
|
||||
|
||||
@ -3,8 +3,7 @@ from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
build_conversation_delete_model,
|
||||
build_conversation_infinite_scroll_pagination_model,
|
||||
build_simple_conversation_model,
|
||||
ConversationDelete,
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
@ -105,7 +104,6 @@ class ConversationApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List all conversations for the current user.
|
||||
|
||||
@ -120,7 +118,7 @@ class ConversationApi(Resource):
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return ConversationService.pagination_by_last_id(
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
@ -129,6 +127,13 @@ class ConversationApi(Resource):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ConversationDelete(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
|
||||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
|
||||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
agent_thought_model = build_agent_thought_model(api_or_ns)
|
||||
message_file_model = build_message_file_model(api_or_ns)
|
||||
|
||||
# Then build the message fields with nested models
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.Raw(
|
||||
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
|
||||
if obj.message_metadata
|
||||
else []
|
||||
),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_model)),
|
||||
}
|
||||
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
@ -104,7 +58,6 @@ class MessageListApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List messages in a conversation.
|
||||
|
||||
@ -119,9 +72,16 @@ class MessageListApi(Resource):
|
||||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app site info.
|
||||
|
||||
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
@ -78,7 +78,7 @@ workflow_run_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/meta")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -8,7 +9,11 @@ from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
@ -54,7 +59,6 @@ class ConversationListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -82,7 +86,7 @@ class ConversationListApi(WebApiResource):
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
@ -92,16 +96,19 @@ class ConversationListApi(WebApiResource):
|
||||
pinned=pinned,
|
||||
sort_by=args["sort_by"],
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>")
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Conversation")
|
||||
@web_ns.doc(description="Delete a specific conversation.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@ -115,7 +122,6 @@ class ConversationApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -126,7 +132,7 @@ class ConversationApi(WebApiResource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
@ -155,7 +161,6 @@ class ConversationRenameApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -171,17 +176,20 @@ class ConversationRenameApi(WebApiResource):
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/pin")
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Pin Conversation")
|
||||
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@ -195,7 +203,6 @@ class ConversationPinApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -208,15 +215,11 @@ class ConversationPinApi(WebApiResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/unpin")
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Unpin Conversation")
|
||||
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@ -230,7 +233,6 @@ class ConversationUnPinApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -239,4 +241,4 @@ class ConversationUnPinApi(WebApiResource):
|
||||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
@ -2,8 +2,7 @@ import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@ -70,29 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
|
||||
|
||||
@web_ns.route("/messages")
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Message List")
|
||||
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
|
||||
@web_ns.doc(
|
||||
@ -121,7 +96,6 @@ class MessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -131,9 +105,16 @@ class MessageListApi(WebApiResource):
|
||||
query = MessageListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, query.conversation_id, query.first_id, query.limit
|
||||
)
|
||||
adapter = TypeAdapter(WebMessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return WebMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
@ -142,10 +123,6 @@ class MessageListApi(WebApiResource):
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Create Message Feedback")
|
||||
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
@ -170,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -187,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
|
||||
@ -247,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Suggested Questions")
|
||||
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
@ -264,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -277,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||
)
|
||||
# questions is a list of strings, not a list of Message objects
|
||||
# so we can directly return it
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
@ -296,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
||||
@ -1,40 +1,20 @@
|
||||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
post_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Saved Messages")
|
||||
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
|
||||
@web_ns.doc(
|
||||
@ -58,7 +38,6 @@ class SavedMessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -70,7 +49,14 @@ class SavedMessageListApi(WebApiResource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@web_ns.doc("Save Message")
|
||||
@web_ns.doc(description="Save a specific message for later reference.")
|
||||
@ -89,7 +75,6 @@ class SavedMessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -102,15 +87,11 @@ class SavedMessageListApi(WebApiResource):
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages/<uuid:message_id>")
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Saved Message")
|
||||
@web_ns.doc(description="Remove a message from saved messages.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
|
||||
@ -124,7 +105,6 @@ class SavedMessageApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -133,4 +113,4 @@ class SavedMessageApi(WebApiResource):
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and scratchpad.action:
|
||||
if scratchpad.action.action_name.lower() != "final answer":
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
if usage_dict["usage"] is not None:
|
||||
|
||||
@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.workflow.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
final_answer += response + "\n"
|
||||
|
||||
# Check if max iteration is reached and model still wants to call tools
|
||||
if iteration_step == max_iteration_steps and tool_calls:
|
||||
raise AgentMaxIterationError(app_config.agent.max_iteration)
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
||||
|
||||
@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: list[ModelType]
|
||||
|
||||
def __init__(self, provider_entity: ProviderEntity):
|
||||
@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
label=provider_entity.label,
|
||||
icon_small=provider_entity.icon_small,
|
||||
icon_small_dark=provider_entity.icon_small_dark,
|
||||
icon_large=provider_entity.icon_large,
|
||||
supported_model_types=provider_entity.supported_model_types,
|
||||
)
|
||||
|
||||
@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType] = []
|
||||
|
||||
|
||||
|
||||
@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _inject_trace_headers(headers: dict | None) -> dict:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
|
||||
When OTEL is disabled, we manually inject the traceparent header.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return headers
|
||||
|
||||
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
|
||||
if dify_config.ENABLE_OTEL:
|
||||
return headers
|
||||
|
||||
# Generate and inject traceparent for non-OTEL scenarios
|
||||
try:
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
except Exception:
|
||||
# Ignore errors to avoid breaking requests; log at debug level for troubleshooting
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
return headers
|
||||
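The helper above only touches the header dict when no traceparent is present and OTEL is disabled; otherwise it leaves propagation to HTTPXClientInstrumentor. A rough, self-contained sketch of that decision flow (not the real function, which reads dify_config and the helper in trace_id_helper):

def inject_traceparent(headers: dict | None, otel_enabled: bool, make_traceparent) -> dict:
    """Illustrative mirror of _inject_trace_headers' branching."""
    headers = headers or {}
    if any(k.lower() == "traceparent" for k in headers):
        return headers  # caller already set one
    if otel_enabled:
        return headers  # instrumentation propagates the context itself
    traceparent = make_traceparent()
    if traceparent:
        headers["traceparent"] = traceparent
    return headers


fake = lambda: "00-" + "ab" * 16 + "-" + "cd" * 8 + "-01"
print(inject_traceparent({}, otel_enabled=False, make_traceparent=fake))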
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
# Convert requests-style allow_redirects to httpx-style follow_redirects
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
if "follow_redirects" not in kwargs:
|
||||
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
|
||||
headers = kwargs.get("headers") or {}
|
||||
headers = _inject_trace_headers(headers)
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Preserve user-provided Host header
|
||||
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||
headers = kwargs.get("headers", {})
|
||||
user_provided_host = _get_user_provided_host_header(headers)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
# Build the request manually to preserve the Host header
|
||||
# httpx may override the Host header when using a proxy, so we use
|
||||
# the request API to explicitly set headers before sending
|
||||
# Preserve the user-provided Host header
|
||||
# httpx may override the Host header when using a proxy
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||
if user_provided_host is not None:
|
||||
headers["host"] = user_provided_host
|
||||
|
||||
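Taken together, requests-style kwargs keep working and an explicitly provided Host header now survives the proxy hop. A hedged usage sketch (URL and host are placeholders, not values from this change):

# placeholders only; make_request is the function defined above
response = make_request(
    "GET",
    "http://10.0.0.5/healthz",
    allow_redirects=True,  # converted to httpx's follow_redirects internally
    headers={"Host": "service.internal.example"},  # preserved even behind a forward proxy
)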
@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
|
||||
if len(parts) == 4 and len(parts[1]) == 32:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def get_span_id_from_otel_context() -> str | None:
|
||||
"""
|
||||
Retrieve the current span ID from the active OpenTelemetry trace context.
|
||||
|
||||
Returns:
|
||||
A 16-character hex string representing the span ID, or None if not available.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID
|
||||
|
||||
span = get_current_span()
|
||||
if not span:
|
||||
return None
|
||||
|
||||
span_context = span.get_span_context()
|
||||
if not span_context or span_context.span_id == INVALID_SPAN_ID:
|
||||
return None
|
||||
|
||||
return f"{span_context.span_id:016x}"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def generate_traceparent_header() -> str | None:
|
||||
"""
|
||||
Generate a W3C traceparent header from the current context.
|
||||
|
||||
Uses OpenTelemetry context if available, otherwise uses the
|
||||
ContextVar-based trace_id from the logging context.
|
||||
|
||||
Format: {version}-{trace_id}-{span_id}-{flags}
|
||||
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
||||
|
||||
Returns:
|
||||
A valid traceparent header string, or None if generation fails.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Try OTEL context first
|
||||
trace_id = get_trace_id_from_otel_context()
|
||||
span_id = get_span_id_from_otel_context()
|
||||
|
||||
if trace_id and span_id:
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
||||
# Fallback: use ContextVar-based trace_id or generate new one
|
||||
from core.logging.context import get_trace_id as get_logging_trace_id
|
||||
|
||||
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
||||
|
||||
# Generate a new span_id (16 hex chars)
|
||||
span_id = uuid.uuid4().hex[:16]
|
||||
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
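The returned header follows the W3C shape 00-<32 hex trace id>-<16 hex span id>-01. A quick shape check, assuming the function is importable from core.helper.trace_id_helper as referenced elsewhere in this change:

from core.helper.trace_id_helper import generate_traceparent_header

header = generate_traceparent_header()
if header:  # generation can return None on failure
    version, trace_id, span_id, flags = header.split("-")
    assert version == "00" and flags == "01"
    assert len(trace_id) == 32 and len(span_id) == 16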
|
||||
20
api/core/logging/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""Structured logging components for Dify."""
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
__all__ = [
|
||||
"IdentityContextFilter",
|
||||
"StructuredJSONFormatter",
|
||||
"TraceContextFilter",
|
||||
"clear_request_context",
|
||||
"get_request_id",
|
||||
"get_trace_id",
|
||||
"init_request_context",
|
||||
]
|
||||
35
api/core/logging/context.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Request context for logging - framework agnostic.
|
||||
|
||||
This module provides request-scoped context variables for logging,
|
||||
using Python's contextvars for thread-safe and async-safe storage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
|
||||
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get current request ID (10 hex chars)."""
|
||||
return _request_id.get()
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
|
||||
return _trace_id.get()
|
||||
|
||||
|
||||
def init_request_context() -> None:
|
||||
"""Initialize request context. Call at start of each request."""
|
||||
req_id = uuid.uuid4().hex[:10]
|
||||
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
|
||||
_request_id.set(req_id)
|
||||
_trace_id.set(trace_id)
|
||||
|
||||
|
||||
def clear_request_context() -> None:
|
||||
"""Clear request context. Call at end of request (optional)."""
|
||||
_request_id.set("")
|
||||
_trace_id.set("")
|
||||
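Typical usage is to call init_request_context() once per request (for example in a Flask before_request hook) so every log record emitted during that request shares the same IDs. A minimal sketch, assuming the module path shown in the file header:

from core.logging.context import (
    clear_request_context,
    get_request_id,
    get_trace_id,
    init_request_context,
)

init_request_context()              # e.g. in a before_request hook
assert len(get_request_id()) == 10  # short per-request id
assert len(get_trace_id()) == 32    # deterministic fallback trace id
clear_request_context()             # optional, e.g. in teardown_request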
94
api/core/logging/filters.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Logging filters for structured logging."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import flask
|
||||
|
||||
from core.logging.context import get_request_id, get_trace_id
|
||||
|
||||
|
||||
class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds trace_id and span_id to log records.
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
|
||||
return True
|
||||
|
||||
def _get_otel_context(self) -> tuple[str, str]:
|
||||
"""Extract trace_id and span_id from OpenTelemetry context."""
|
||||
with contextlib.suppress(Exception):
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
span = get_current_span()
|
||||
if span and span.get_span_context():
|
||||
ctx = span.get_span_context()
|
||||
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
|
||||
trace_id = f"{ctx.trace_id:032x}"
|
||||
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
|
||||
return trace_id, span_id
|
||||
return "", ""
|
||||
|
||||
|
||||
class IdentityContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds user identity context to log records.
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
"""Extract identity from current_user if in request context."""
|
||||
try:
|
||||
if not flask.has_request_context():
|
||||
return {}
|
||||
from flask_login import current_user
|
||||
|
||||
# Check if user is authenticated using the proxy
|
||||
if not current_user.is_authenticated:
|
||||
return {}
|
||||
|
||||
# Access the underlying user object
|
||||
user = current_user
|
||||
|
||||
from models import Account
|
||||
from models.model import EndUser
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
|
||||
if isinstance(user, Account):
|
||||
if user.current_tenant_id:
|
||||
identity["tenant_id"] = user.current_tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = "account"
|
||||
elif isinstance(user, EndUser):
|
||||
identity["tenant_id"] = user.tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = user.type or "end_user"
|
||||
|
||||
return identity
|
||||
except Exception:
|
||||
return {}
|
||||
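The filters only attach extra attributes to each LogRecord; they still have to be registered on a handler (or logger) before a formatter can read those attributes. A minimal wiring sketch, assuming the module paths above:

import logging

from core.logging.filters import IdentityContextFilter, TraceContextFilter

handler = logging.StreamHandler()
handler.addFilter(TraceContextFilter())     # adds trace_id / span_id / req_id
handler.addFilter(IdentityContextFilter())  # adds tenant_id / user_id / user_type
logging.getLogger().addHandler(handler)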
107
api/core/logging/structured_formatter.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Structured JSON log formatter for Dify."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
{
|
||||
"ts": "ISO 8601 UTC",
|
||||
"severity": "INFO|ERROR|WARN|DEBUG",
|
||||
"service": "service name",
|
||||
"caller": "file:line",
|
||||
"trace_id": "hex 32",
|
||||
"span_id": "hex 16",
|
||||
"identity": { "tenant_id", "user_id", "user_type" },
|
||||
"message": "log message",
|
||||
"attributes": { ... },
|
||||
"stack_trace": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
SEVERITY_MAP: dict[int, str] = {
|
||||
logging.DEBUG: "DEBUG",
|
||||
logging.INFO: "INFO",
|
||||
logging.WARNING: "WARN",
|
||||
logging.ERROR: "ERROR",
|
||||
logging.CRITICAL: "ERROR",
|
||||
}
|
||||
|
||||
def __init__(self, service_name: str | None = None):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
return orjson.dumps(log_dict).decode("utf-8")
|
||||
except TypeError:
|
||||
# Fallback: convert non-serializable objects to string
|
||||
import json
|
||||
|
||||
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||
# Core fields
|
||||
log_dict: dict[str, Any] = {
|
||||
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||
"service": self._service_name,
|
||||
"caller": f"{record.filename}:{record.lineno}",
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
# Trace context (from TraceContextFilter)
|
||||
trace_id = getattr(record, "trace_id", "")
|
||||
span_id = getattr(record, "span_id", "")
|
||||
|
||||
if trace_id:
|
||||
log_dict["trace_id"] = trace_id
|
||||
if span_id:
|
||||
log_dict["span_id"] = span_id
|
||||
|
||||
# Identity context (from IdentityContextFilter)
|
||||
identity = self._extract_identity(record)
|
||||
if identity:
|
||||
log_dict["identity"] = identity
|
||||
|
||||
# Dynamic attributes
|
||||
attributes = getattr(record, "attributes", None)
|
||||
if attributes:
|
||||
log_dict["attributes"] = attributes
|
||||
|
||||
# Stack trace for errors with exceptions
|
||||
if record.exc_info and record.levelno >= logging.ERROR:
|
||||
log_dict["stack_trace"] = self._format_exception(record.exc_info)
|
||||
|
||||
return log_dict
|
||||
|
||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||
tenant_id = getattr(record, "tenant_id", None)
|
||||
user_id = getattr(record, "user_id", None)
|
||||
user_type = getattr(record, "user_type", None)
|
||||
|
||||
if not any([tenant_id, user_id, user_type]):
|
||||
return None
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
if tenant_id:
|
||||
identity["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
identity["user_id"] = user_id
|
||||
if user_type:
|
||||
identity["user_type"] = user_type
|
||||
return identity
|
||||
|
||||
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
|
||||
if exc_info and exc_info[0] is not None:
|
||||
return "".join(traceback.format_exception(*exc_info))
|
||||
return ""
|
||||
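Combined with the filters, this formatter emits one JSON object per log line. A minimal end-to-end sketch, assuming the exports from core.logging shown earlier; the service name is a placeholder:

import logging

from core.logging import StructuredJSONFormatter, TraceContextFilter, init_request_context

init_request_context()
handler = logging.StreamHandler()
handler.addFilter(TraceContextFilter())
handler.setFormatter(StructuredJSONFormatter(service_name="api"))  # placeholder name

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("hello")  # -> {"ts": "...", "severity": "INFO", "service": "api", "message": "hello", ...}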
@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
|
||||
@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
|
||||
provider=self.provider,
|
||||
label=self.label,
|
||||
icon_small=self.icon_small,
|
||||
icon_large=self.icon_large,
|
||||
supported_model_types=self.supported_model_types,
|
||||
models=self.models,
|
||||
)
|
||||
|
||||
@ -285,7 +285,7 @@ class ModelProviderFactory:
|
||||
"""
|
||||
Get provider icon
|
||||
:param provider: provider name
|
||||
:param icon_type: icon type (icon_small or icon_large)
|
||||
:param icon_type: icon type (icon_small or icon_small_dark)
|
||||
:param lang: language (zh_Hans or en_US)
|
||||
:return: provider icon
|
||||
"""
|
||||
@ -309,13 +309,7 @@ class ModelProviderFactory:
|
||||
else:
|
||||
file_name = provider_schema.icon_small_dark.en_US
|
||||
else:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
if not file_name:
|
||||
raise ValueError(f"Provider {provider} does not have icon.")
|
||||
|
||||
@ -103,6 +103,9 @@ class BasePluginClient:
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
# Inject traceparent header for distributed tracing
|
||||
self._inject_trace_headers(prepared_headers)
|
||||
|
||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
@ -114,6 +117,31 @@ class BasePluginClient:
|
||||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
This ensures trace context is propagated to the plugin daemon even if
|
||||
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
|
||||
"""
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
import contextlib
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return
|
||||
|
||||
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
|
||||
with contextlib.suppress(Exception):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
|
||||
@ -331,7 +331,6 @@ class ProviderManager:
|
||||
provider=provider_schema.provider,
|
||||
label=provider_schema.label,
|
||||
icon_small=provider_schema.icon_small,
|
||||
icon_large=provider_schema.icon_large,
|
||||
supported_model_types=provider_schema.supported_model_types,
|
||||
),
|
||||
)
|
||||
|
||||
@ -112,7 +112,7 @@ class ExtractProcessor:
                if file_extension in {".xlsx", ".xls"}:
                    extractor = ExcelExtractor(file_path)
                elif file_extension == ".pdf":
                    extractor = PdfExtractor(file_path)
                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                elif file_extension in {".md", ".markdown", ".mdx"}:
                    extractor = (
                        UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -148,7 +148,7 @@ class ExtractProcessor:
                if file_extension in {".xlsx", ".xls"}:
                    extractor = ExcelExtractor(file_path)
                elif file_extension == ".pdf":
                    extractor = PdfExtractor(file_path)
                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                elif file_extension in {".md", ".markdown", ".mdx"}:
                    extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                elif file_extension in {".htm", ".html"}:

@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""

import contextlib
import io
import logging
import uuid
from collections.abc import Iterator

import pypdfium2
import pypdfium2.raw as pdfium_c

from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfExtractor(BaseExtractor):
    """Load pdf files.

    """
    PdfExtractor is used to extract text and images from PDF files.

    Args:
        file_path: Path to the file to load.
        file_path: Path to the PDF file.
        tenant_id: Workspace ID.
        user_id: ID of the user performing the extraction.
        file_cache_key: Optional cache key for the extracted text.
    """

    def __init__(self, file_path: str, file_cache_key: str | None = None):
        """Initialize with file path."""
    # Magic bytes for image format detection: (magic_bytes, extension, mime_type)
    IMAGE_FORMATS = [
        (b"\xff\xd8\xff", "jpg", "image/jpeg"),
        (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
        (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
        (b"GIF8", "gif", "image/gif"),
        (b"BM", "bmp", "image/bmp"),
        (b"II*\x00", "tiff", "image/tiff"),
        (b"MM\x00*", "tiff", "image/tiff"),
        (b"II+\x00", "tiff", "image/tiff"),
        (b"MM\x00+", "tiff", "image/tiff"),
    ]
    MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)

    def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
        """Initialize PdfExtractor."""
        self._file_path = file_path
        self._tenant_id = tenant_id
        self._user_id = user_id
        self._file_cache_key = file_cache_key

    def extract(self) -> list[Document]:
@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):

    def parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily parse the blob."""
        import pypdfium2  # type: ignore

        with blob.as_bytes_io() as file_path:
            pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
                    text_page = page.get_textpage()
                    content = text_page.get_text_range()
                    text_page.close()

                    image_content = self._extract_images(page)
                    if image_content:
                        content += "\n" + image_content

                    page.close()
                    metadata = {"source": blob.source, "page": page_number}
                    yield Document(page_content=content, metadata=metadata)
            finally:
                pdf_reader.close()

    def _extract_images(self, page) -> str:
        """
        Extract images from a PDF page, save them to storage and database,
        and return markdown image links.

        Args:
            page: pypdfium2 page object.

        Returns:
            Markdown string containing links to the extracted images.
        """
        image_content = []
        upload_files = []
        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL

        try:
            image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
            for obj in image_objects:
                try:
                    # Extract image bytes
                    img_byte_arr = io.BytesIO()
                    # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
                    # Fallback to png for other formats
                    obj.extract(img_byte_arr, fb_format="png")
                    img_bytes = img_byte_arr.getvalue()

                    if not img_bytes:
                        continue

                    header = img_bytes[: self.MAX_MAGIC_LEN]
                    image_ext = None
                    mime_type = None
                    for magic, ext, mime in self.IMAGE_FORMATS:
                        if header.startswith(magic):
                            image_ext = ext
                            mime_type = mime
                            break

                    if not image_ext or not mime_type:
                        continue

                    file_uuid = str(uuid.uuid4())
                    file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext

                    storage.save(file_key, img_bytes)

                    # save file to db
                    upload_file = UploadFile(
                        tenant_id=self._tenant_id,
                        storage_type=dify_config.STORAGE_TYPE,
                        key=file_key,
                        name=file_key,
                        size=len(img_bytes),
                        extension=image_ext,
                        mime_type=mime_type,
                        created_by=self._user_id,
                        created_by_role=CreatorUserRole.ACCOUNT,
                        created_at=naive_utc_now(),
                        used=True,
                        used_by=self._user_id,
                        used_at=naive_utc_now(),
                    )
                    upload_files.append(upload_file)
                    image_content.append(f"")
                except Exception as e:
                    logger.warning("Failed to extract image from PDF: %s", e)
                    continue
        except Exception as e:
            logger.warning("Failed to get objects from PDF page: %s", e)
        if upload_files:
            db.session.add_all(upload_files)
            db.session.commit()
        return "\n".join(image_content)

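Note: the `IMAGE_FORMATS` table drives a simple magic-byte prefix check before an extracted image is persisted. A self-contained sketch of that detection step (the reduced format list and names below are illustrative, not part of the diff):

```python
# Minimal sketch of prefix-based image sniffing, mirroring the IMAGE_FORMATS table above.
IMAGE_FORMATS = [
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
    (b"GIF8", "gif", "image/gif"),
]
MAX_MAGIC_LEN = max(len(magic) for magic, _, _ in IMAGE_FORMATS)


def sniff_image(data: bytes) -> tuple[str, str] | None:
    """Return (extension, mime_type) when the byte prefix matches a known format, else None."""
    header = data[:MAX_MAGIC_LEN]
    for magic, ext, mime in IMAGE_FORMATS:
        if header.startswith(magic):
            return ext, mime
    return None


print(sniff_image(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8))  # -> ('png', 'image/png')
```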
@ -515,6 +515,7 @@ class DatasetRetrieval:
|
||||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
dataset_count = len(available_datasets)
|
||||
with measure_time() as timer:
|
||||
cancel_event = threading.Event()
|
||||
thread_exceptions: list[Exception] = []
|
||||
@ -537,6 +538,7 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
@ -562,6 +564,7 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
@ -1422,6 +1425,7 @@ class DatasetRetrieval:
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
dataset_count: int,
|
||||
cancel_event: threading.Event | None = None,
|
||||
thread_exceptions: list[Exception] | None = None,
|
||||
):
|
||||
@ -1470,37 +1474,38 @@ class DatasetRetrieval:
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
|
||||
@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
        self.expected_type = expected_type
        self.actual_type = actual_type
        super().__init__(message)


class AgentMaxIterationError(AgentNodeError):
    """Exception raised when the agent exceeds the maximum iteration limit."""

    def __init__(self, max_iteration: int):
        self.max_iteration = max_iteration
        super().__init__(
            f"Agent exceeded the maximum iteration limit of {max_iteration}. "
            f"The agent was unable to complete the task within the allowed number of iterations."
        )

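Note: callers are expected to raise this once an agent loop runs past its configured cap. A minimal sketch, assuming a simplified loop and hypothetical step callables (not taken from the diff):

```python
# Illustrative only: a simplified agent loop that raises when the iteration cap is hit.
class AgentNodeError(Exception):
    pass


class AgentMaxIterationError(AgentNodeError):
    def __init__(self, max_iteration: int):
        self.max_iteration = max_iteration
        super().__init__(f"Agent exceeded the maximum iteration limit of {max_iteration}.")


def run_agent(steps, max_iteration: int = 5):
    for i, step in enumerate(steps, start=1):
        if i > max_iteration:
            raise AgentMaxIterationError(max_iteration)
        step()  # execute one tool call / reasoning step
```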
@ -1,8 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -20,9 +20,41 @@ from .exc import (
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
Python3CodeProvider,
|
||||
JavascriptCodeProvider,
|
||||
)
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
self._limits = code_limits
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
code_provider: type[CodeNodeProvider] = next(
|
||||
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
|
||||
)
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
|
||||
return cls._DEFAULT_CODE_PROVIDERS
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
|
||||
for provider in self._code_providers:
|
||||
if provider.is_accept_language(code_language):
|
||||
return provider
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
if len(value) > self._limits.max_string_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
f" less than {self._limits.max_string_length} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
if value > self._limits.max_number or value < self._limits.min_number:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
if precision > self._limits.max_precision:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
f" it must be less than {self._limits.max_precision} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
if depth > self._limits.max_depth:
|
||||
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > self._limits.max_number_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_number_array_length} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_string_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_string_array_length} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_object_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_object_array_length} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
|
||||
13
api/core/workflow/nodes/code/limits.py
Normal file
@ -0,0 +1,13 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class CodeNodeLimits:
    max_string_length: int
    max_number: int | float
    min_number: int | float
    max_precision: int
    max_depth: int
    max_number_array_length: int
    max_string_array_length: int
    max_object_array_length: int
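Note: later in this commit, `DifyNodeFactory` fills this dataclass from the `dify_config.CODE_MAX_*` settings. A small usage sketch with placeholder values (the literals below are illustrative, not dify's defaults):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CodeNodeLimits:  # mirrors the definition above
    max_string_length: int
    max_number: int | float
    min_number: int | float
    max_precision: int
    max_depth: int
    max_number_array_length: int
    max_string_array_length: int
    max_object_array_length: int


# Placeholder values; in the diff these come from dify_config.CODE_MAX_* settings.
limits = CodeNodeLimits(
    max_string_length=80_000,
    max_number=9_223_372_036_854_775_807,
    min_number=-9_223_372_036_854_775_807,
    max_precision=20,
    max_depth=5,
    max_number_array_length=1_000,
    max_string_array_length=30,
    max_object_array_length=30,
)
assert len("hello") <= limits.max_string_length  # output validators compare against these caps
```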
@ -1,10 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
|
||||
)
|
||||
self._code_limits = code_limits or CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
@ -72,6 +103,26 @@ class DifyNodeFactory(NodeFactory):
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Protocol

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage


class TemplateRenderError(ValueError):
    """Raised when rendering a Jinja2 template fails."""


class Jinja2TemplateRenderer(Protocol):
    """Render Jinja2 templates for template transform nodes."""

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        """Render a Jinja2 template with provided variables."""
        raise NotImplementedError


class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
    """Adapter that renders Jinja2 templates via CodeExecutor."""

    _code_executor: type[CodeExecutor]

    def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
        self._code_executor = code_executor or CodeExecutor

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        try:
            result = self._code_executor.execute_workflow_code_template(
                language=CodeLanguage.JINJA2, code=template, inputs=variables
            )
        except CodeExecutionError as exc:
            raise TemplateRenderError(str(exc)) from exc

        rendered = result.get("result")
        if not isinstance(rendered, str):
            raise TemplateRenderError("Template render result must be a string.")
        return rendered
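Note: because `Jinja2TemplateRenderer` is a `Protocol`, tests (or alternative deployments) can swap in a renderer that avoids the sandboxed `CodeExecutor` entirely. A hypothetical in-memory stand-in, assuming the `jinja2` package is installed (not part of this commit):

```python
from collections.abc import Mapping
from typing import Any


class InMemoryJinja2Renderer:
    """Illustrative renderer that satisfies the protocol without the sandbox."""

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        from jinja2 import Template  # requires the jinja2 package

        return Template(template).render(**variables)


renderer = InMemoryJinja2Renderer()
print(renderer.render_template("Hello {{ name }}!", {"name": "Dify"}))  # Hello Dify!
```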
@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
    CodeExecutorJinja2TemplateRenderer,
    Jinja2TemplateRenderer,
    TemplateRenderError,
)

if TYPE_CHECKING:
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH


class TemplateTransformNode(Node[TemplateTransformNodeData]):
    node_type = NodeType.TEMPLATE_TRANSFORM
    _template_renderer: Jinja2TemplateRenderer

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        template_renderer: Jinja2TemplateRenderer | None = None,
    ) -> None:
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
            variables[variable_name] = value.to_object() if value else None
        # Run code
        try:
            result = CodeExecutor.execute_workflow_code_template(
                language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
            )
        except CodeExecutionError as e:
            rendered = self._template_renderer.render_template(self.node_data.template, variables)
        except TemplateRenderError as e:
            return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

        if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
        if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
            return NodeRunResult(
                inputs=variables,
                status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
            status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
        )

    @classmethod

@ -46,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:
            from core.logging.context import init_request_context

            with app.app_context():
                # Initialize logging context for this task (similar to before_request in Flask)
                init_request_context()
                return self.run(*args, **kwargs)

    broker_transport_options = {}

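Note: this is the standard Flask/Celery integration pattern, condensed so every task body runs inside the application context. A generic sketch of the same idea (dify's `init_app` additionally wires broker, SSL, logging, and transport options):

```python
# Generic Flask + Celery wiring; names here are illustrative, not dify's actual extension code.
from celery import Celery, Task
from flask import Flask


def make_celery(app: Flask) -> Celery:
    class FlaskTask(Task):
        def __call__(self, *args: object, **kwargs: object) -> object:
            # Run every task inside the app context so db sessions, config, etc. are available.
            with app.app_context():
                return self.run(*args, **kwargs)

    celery = Celery(app.name, task_cls=FlaskTask)
    celery.conf.update(app.config.get("CELERY", {}))
    return celery
```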
@ -11,6 +11,7 @@ def init_app(app: DifyApp):
        create_tenant,
        extract_plugins,
        extract_unique_plugins,
        file_usage,
        fix_app_site_missing,
        install_plugins,
        install_rag_pipeline_plugins,
@ -47,6 +48,7 @@ def init_app(app: DifyApp):
        clear_free_plan_tenant_expired_logs,
        clear_orphaned_file_records,
        remove_orphaned_files_on_storage,
        file_usage,
        setup_system_tool_oauth_client,
        setup_system_trigger_oauth_client,
        cleanup_orphaned_draft_variables,

@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
    db.init_app(app)
    _setup_gevent_compatibility()

    # Eagerly build the engine so pool_size/max_overflow/etc. come from config
    try:
        with app.app_context():
            _ = db.engine  # triggers engine creation with the configured options
    except Exception:
        logger.exception("Failed to initialize SQLAlchemy engine during app startup")

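Note: touching `db.engine` inside an application context forces SQLAlchemy to build the engine and its connection pool at startup, so pool misconfiguration fails fast instead of on the first request. A generic Flask-SQLAlchemy sketch of the same idea (placeholder URI and options, not dify's config):

```python
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"  # placeholder URI
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"pool_pre_ping": True}  # placeholder options

db = SQLAlchemy()
db.init_app(app)

with app.app_context():
    engine = db.engine  # engine (and pool) is created here with the configured options
    print(engine.url)
```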
@ -1,18 +1,19 @@
|
||||
"""Logging extension for Dify Flask application."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import flask
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.trace_id_helper import get_trace_id_from_otel_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Initialize logging with support for text or JSON format."""
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
# File handler
|
||||
log_file = dify_config.LOG_FILE
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
@ -25,27 +26,53 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
)
|
||||
|
||||
# Always add StreamHandler to log to console
|
||||
# Console handler
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
log_handlers.append(sh)
|
||||
|
||||
# Apply RequestIdFilter to all handlers
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(RequestIdFilter())
|
||||
# Apply filters to all handlers
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(TraceContextFilter())
|
||||
handler.addFilter(IdentityContextFilter())
|
||||
|
||||
# Configure formatter based on format type
|
||||
formatter = _create_formatter()
|
||||
for handler in log_handlers:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=dify_config.LOG_LEVEL,
|
||||
format=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Apply RequestIdFormatter to all handlers
|
||||
apply_request_id_formatter()
|
||||
|
||||
# Disable propagation for noisy loggers to avoid duplicate logs
|
||||
logging.getLogger("sqlalchemy.engine").propagate = False
|
||||
|
||||
# Apply timezone if specified (only for text format)
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "text":
|
||||
_apply_timezone(log_handlers)
|
||||
|
||||
|
||||
def _create_formatter() -> logging.Formatter:
|
||||
"""Create appropriate formatter based on configuration."""
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "json":
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
return StructuredJSONFormatter()
|
||||
else:
|
||||
# Text format - use existing pattern with backward compatible formatter
|
||||
return _TextFormatter(
|
||||
fmt=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
)
|
||||
|
||||
|
||||
def _apply_timezone(handlers: list[logging.Handler]):
|
||||
"""Apply timezone conversion to text formatters."""
|
||||
log_tz = dify_config.LOG_TZ
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
@ -57,34 +84,51 @@ def init_app(app: DifyApp):
|
||||
def time_converter(seconds):
|
||||
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
for handler in handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter.converter = time_converter
|
||||
handler.formatter.converter = time_converter # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def get_request_id():
|
||||
if getattr(flask.g, "request_id", None):
|
||||
return flask.g.request_id
|
||||
class _TextFormatter(logging.Formatter):
|
||||
"""Text formatter that ensures trace_id and req_id are always present."""
|
||||
|
||||
new_uuid = uuid.uuid4().hex[:10]
|
||||
flask.g.request_id = new_uuid
|
||||
|
||||
return new_uuid
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
record.trace_id = ""
|
||||
if not hasattr(record, "span_id"):
|
||||
record.span_id = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get request ID for current request context.
|
||||
|
||||
Deprecated: Use core.logging.context.get_request_id() directly.
|
||||
"""
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
|
||||
return _get_request_id()
|
||||
|
||||
|
||||
# Backward compatibility aliases
|
||||
class RequestIdFilter(logging.Filter):
|
||||
# This is a logging filter that makes the request ID available for use in
|
||||
# the logging format. Note that we're checking if we're in a request
|
||||
# context, as we may want to log things before Flask is fully loaded.
|
||||
def filter(self, record):
|
||||
trace_id = get_trace_id_from_otel_context() or ""
|
||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
||||
record.trace_id = trace_id
|
||||
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
from core.logging.context import get_trace_id as _get_trace_id
|
||||
|
||||
record.req_id = _get_request_id()
|
||||
record.trace_id = _get_trace_id()
|
||||
return True
|
||||
|
||||
|
||||
class RequestIdFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
"""Deprecated: Use _TextFormatter instead."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
|
||||
|
||||
|
||||
def apply_request_id_formatter():
|
||||
"""Deprecated: Formatter is now applied in init_app."""
|
||||
for handler in logging.root.handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
||||
|
||||
@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)


def to_serializable(obj):
    """
    Convert non-JSON-serializable objects into JSON-compatible formats.

    - Uses `to_dict()` if it's a callable method.
    - Falls back to string representation.
    """
    if hasattr(obj, "to_dict") and callable(obj.to_dict):
        return obj.to_dict()
    return str(obj)


class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
    def __init__(
        self,
@ -108,9 +120,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
            ),
            ("type", domain_model.workflow_type.value),
            ("version", domain_model.workflow_version),
            ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
            ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
            ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
            (
                "graph",
                json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
                if domain_model.graph
                else "{}",
            ),
            (
                "inputs",
                json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
                if domain_model.inputs
                else "{}",
            ),
            (
                "outputs",
                json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
                if domain_model.outputs
                else "{}",
            ),
            ("status", domain_model.status.value),
            ("error_message", domain_model.error_message or ""),
            ("total_tokens", str(domain_model.total_tokens)),

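Note: `to_serializable` is passed as the `default=` hook to `json.dumps`, so values such as datetimes or domain objects no longer raise `TypeError` during serialization. A quick standalone example of the same fallback behavior:

```python
import json
from datetime import datetime


def to_serializable(obj):
    """Fallback for json.dumps: prefer an object's to_dict(), else use str()."""
    if hasattr(obj, "to_dict") and callable(obj.to_dict):
        return obj.to_dict()
    return str(obj)


payload = {"when": datetime(2024, 1, 1), "graph": {"nodes": []}}
print(json.dumps(payload, ensure_ascii=False, default=to_serializable))
# {"when": "2024-01-01 00:00:00", "graph": {"nodes": []}}
```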
@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)


class ExceptionLoggingHandler(logging.Handler):
    """
    Handler that records exceptions to the current OpenTelemetry span.

    Unlike creating a new span, this records exceptions on the existing span
    to maintain trace context consistency throughout the request lifecycle.
    """

    def emit(self, record: logging.LogRecord):
        with contextlib.suppress(Exception):
            if record.exc_info:
                tracer = get_tracer_provider().get_tracer("dify.exception.logging")
                with tracer.start_as_current_span(
                    "log.exception",
                    attributes={
                        "log.level": record.levelname,
                        "log.message": record.getMessage(),
                        "log.logger": record.name,
                        "log.file.path": record.pathname,
                        "log.file.line": record.lineno,
                    },
                ) as span:
                    span.set_status(StatusCode.ERROR)
                    if record.exc_info[1]:
                        span.record_exception(record.exc_info[1])
                        span.set_attribute("exception.message", str(record.exc_info[1]))
                    if record.exc_info[0]:
                        span.set_attribute("exception.type", record.exc_info[0].__name__)
            if not record.exc_info:
                return

            from opentelemetry.trace import get_current_span

            span = get_current_span()
            if not span or not span.is_recording():
                return

            # Record exception on the current span instead of creating a new one
            span.set_status(StatusCode.ERROR, record.getMessage())

            # Add log context as span events/attributes
            span.add_event(
                "log.exception",
                attributes={
                    "log.level": record.levelname,
                    "log.message": record.getMessage(),
                    "log.logger": record.name,
                    "log.file.path": record.pathname,
                    "log.file.line": record.lineno,
                },
            )

            if record.exc_info[1]:
                span.record_exception(record.exc_info[1])
            if record.exc_info[0]:
                span.set_attribute("exception.type", record.exc_info[0].__name__)


def instrument_exception_logging() -> None:

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@ -12,7 +12,7 @@ annotation_fields = {
}


def build_annotation_model(api_or_ns: Api | Namespace):
def build_annotation_model(api_or_ns: Namespace):
    """Build the annotation model for the API or Namespace."""
    return api_or_ns.model("Annotation", annotation_fields)


@ -1,236 +1,338 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.file import File
|
||||
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
}
|
||||
class MessageFile(ResponseModel):
|
||||
id: str
|
||||
filename: str
|
||||
type: str
|
||||
url: str | None = None
|
||||
mime_type: str | None = None
|
||||
size: int | None = None
|
||||
transfer_method: str
|
||||
belongs_to: str | None = None
|
||||
upload_file_id: str | None = None
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
message_file_fields = {
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
}
|
||||
@field_validator("transfer_method", mode="before")
|
||||
@classmethod
|
||||
def _normalize_transfer_method(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def build_message_file_model(api_or_ns: Api | Namespace):
|
||||
"""Build the message file fields for the API or Namespace."""
|
||||
return api_or_ns.model("MessageFile", message_file_fields)
|
||||
class SimpleConversation(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputs: dict[str, JSONValue]
|
||||
status: str
|
||||
introduction: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_fields)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
}
|
||||
|
||||
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
|
||||
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
|
||||
model_config_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message": fields.Nested(message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_with_summary_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"status_count": fields.Nested(status_count_fields),
|
||||
}
|
||||
|
||||
conversation_with_summary_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
}
|
||||
|
||||
simple_conversation_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"status": fields.String,
|
||||
"introduction": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
conversation_delete_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(simple_conversation_fields)),
|
||||
}
|
||||
class ConversationInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SimpleConversation]
|
||||
|
||||
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||
|
||||
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
|
||||
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
|
||||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||
class ConversationDelete(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_conversation_delete_model(api_or_ns: Api | Namespace):
|
||||
"""Build the conversation delete model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||
class ResultResponse(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_simple_conversation_model(api_or_ns: Api | Namespace):
|
||||
"""Build the simple conversation model for the API or Namespace."""
|
||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class Feedback(ResponseModel):
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account: SimpleAccount | None = None
|
||||
|
||||
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
content: str
|
||||
account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
annotation_id: str
|
||||
annotation_create_account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AgentThought(ResponseModel):
|
||||
id: str
|
||||
chain_id: str | None = None
|
||||
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
|
||||
message_id: str
|
||||
position: int
|
||||
thought: str | None = None
|
||||
tool: str | None = None
|
||||
tool_labels: JSONValue
|
||||
tool_input: str | None = None
|
||||
created_at: int | None = None
|
||||
observation: str | None = None
|
||||
files: list[str]
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _fallback_chain_id(self):
|
||||
if self.chain_id is None and self.message_chain_id:
|
||||
self.chain_id = self.message_chain_id
|
||||
return self
|
||||
|
||||
|
||||
class MessageDetail(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: JSONValue
|
||||
message_tokens: int
|
||||
answer: str
|
||||
answer_tokens: int
|
||||
provider_response_latency: float
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
feedbacks: list[Feedback]
|
||||
workflow_run_id: str | None = None
|
||||
annotation: Annotation | None = None
|
||||
annotation_hit_history: AnnotationHitHistory | None = None
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
    message_files: list[MessageFile]
    metadata: JSONValue
    status: str
    error: str | None = None
    parent_message_id: str | None = None

    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
        return format_files_contained(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value


class FeedbackStat(ResponseModel):
    like: int
    dislike: int


class StatusCount(ResponseModel):
    success: int
    failed: int
    partial_success: int


class ModelConfig(ResponseModel):
    opening_statement: str | None = None
    suggested_questions: JSONValue | None = None
    model: JSONValue | None = None
    user_input_form: JSONValue | None = None
    pre_prompt: str | None = None
    agent_mode: JSONValue | None = None


class SimpleModelConfig(ResponseModel):
    model: JSONValue | None = None
    pre_prompt: str | None = None


class SimpleMessageDetail(ResponseModel):
    inputs: dict[str, JSONValue]
    query: str
    message: str
    answer: str

    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
        return format_files_contained(value)


class Conversation(ResponseModel):
    id: str
    status: str
    from_source: str
    from_end_user_id: str | None = None
    from_end_user_session_id: str | None = None
    from_account_id: str | None = None
    from_account_name: str | None = None
    read_at: int | None = None
    created_at: int | None = None
    updated_at: int | None = None
    annotation: Annotation | None = None
    model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
    user_feedback_stats: FeedbackStat | None = None
    admin_feedback_stats: FeedbackStat | None = None
    message: SimpleMessageDetail | None = None


class ConversationPagination(ResponseModel):
    page: int
    limit: int
    total: int
    has_more: bool
    data: list[Conversation]


class ConversationMessageDetail(ResponseModel):
    id: str
    status: str
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: int | None = None
    model_config_: ModelConfig | None = Field(default=None, alias="model_config")
    message: MessageDetail | None = None


class ConversationWithSummary(ResponseModel):
    id: str
    status: str
    from_source: str
    from_end_user_id: str | None = None
    from_end_user_session_id: str | None = None
    from_account_id: str | None = None
    from_account_name: str | None = None
    name: str
    summary: str
    read_at: int | None = None
    created_at: int | None = None
    updated_at: int | None = None
    annotated: bool
    model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
    message_count: int
    user_feedback_stats: FeedbackStat | None = None
    admin_feedback_stats: FeedbackStat | None = None
    status_count: StatusCount | None = None


class ConversationWithSummaryPagination(ResponseModel):
    page: int
    limit: int
    total: int
    has_more: bool
    data: list[ConversationWithSummary]


class ConversationDetail(ResponseModel):
    id: str
    status: str
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: int | None = None
    updated_at: int | None = None
    annotated: bool
    introduction: str | None = None
    model_config_: ModelConfig | None = Field(default=None, alias="model_config")
    message_count: int
    user_feedback_stats: FeedbackStat | None = None
    admin_feedback_stats: FeedbackStat | None = None


def to_timestamp(value: datetime | None) -> int | None:
    if value is None:
        return None
    return int(value.timestamp())


def format_files_contained(value: JSONValue) -> JSONValue:
    if isinstance(value, File):
        return value.model_dump()
    if isinstance(value, dict):
        return {k: format_files_contained(v) for k, v in value.items()}
    if isinstance(value, list):
        return [format_files_contained(v) for v in value]
    return value


def message_text(value: JSONValue) -> str:
    if isinstance(value, list) and value:
        first = value[0]
        if isinstance(first, dict):
            text = first.get("text")
            if isinstance(text, str):
                return text
    return ""


def extract_model_config(value: object | None) -> dict[str, JSONValue]:
    if value is None:
        return {}
    if isinstance(value, dict):
        return value
    if hasattr(value, "to_dict"):
        return value.to_dict()
    return {}
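Reviewer note (not part of the diff): the `mode="before"` validators above depend on `format_files_contained` to turn `core.file.File` objects nested anywhere inside `inputs` into plain dicts before Pydantic validation runs. A minimal, self-contained sketch of that behaviour, with `DummyFile` standing in for the real `File` model:

```python
from pydantic import BaseModel


class DummyFile(BaseModel):  # stand-in for core.file.File, which also exposes model_dump()
    id: str
    url: str


def format_files_contained(value):
    # Same recursion as the helper above: File -> dict, containers -> recurse, scalars -> unchanged.
    if isinstance(value, DummyFile):
        return value.model_dump()
    if isinstance(value, dict):
        return {k: format_files_contained(v) for k, v in value.items()}
    if isinstance(value, list):
        return [format_files_contained(v) for v in value]
    return value


print(format_files_contained({"attachments": [DummyFile(id="f1", url="https://example.com/f1")]}))
# {'attachments': [{'id': 'f1', 'url': 'https://example.com/f1'}]}
```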
@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}


def build_conversation_variable_model(api_or_ns: Api | Namespace):
def build_conversation_variable_model(api_or_ns: Namespace):
    """Build the conversation variable model for the API or Namespace."""
    return api_or_ns.model("ConversationVariable", conversation_variable_fields)


def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
    """Build the conversation variable infinite scroll pagination model for the API or Namespace."""
    # Build the nested variable model first
    conversation_variable_model = build_conversation_variable_model(api_or_ns)

@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

simple_end_user_fields = {
    "id": fields.String,
@@ -8,5 +8,5 @@ simple_end_user_fields = {
}


def build_simple_end_user_model(api_or_ns: Api | Namespace):
def build_simple_end_user_model(api_or_ns: Namespace):
    return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@@ -14,7 +14,7 @@ upload_config_fields = {
}


def build_upload_config_model(api_or_ns: Api | Namespace):
def build_upload_config_model(api_or_ns: Namespace):
    """Build the upload config model for the API or Namespace.

    Args:
@@ -39,7 +39,7 @@ file_fields = {
}


def build_file_model(api_or_ns: Api | Namespace):
def build_file_model(api_or_ns: Namespace):
    """Build the file model for the API or Namespace.

    Args:
@@ -57,7 +57,7 @@ remote_file_info_fields = {
}


def build_remote_file_info_model(api_or_ns: Api | Namespace):
def build_remote_file_info_model(api_or_ns: Namespace):
    """Build the remote file info model for the API or Namespace.

    Args:
@@ -81,7 +81,7 @@ file_fields_with_signed_url = {
}


def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
def build_file_with_signed_url_model(api_or_ns: Namespace):
    """Build the file with signed URL model for the API or Namespace.

    Args:

@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import AvatarUrlField, TimestampField

@@ -9,7 +9,7 @@ simple_account_fields = {
}


def build_simple_account_model(api_or_ns: Api | Namespace):
def build_simple_account_model(api_or_ns: Namespace):
    return api_or_ns.model("SimpleAccount", simple_account_fields)
@@ -1,77 +1,137 @@
from flask_restx import Api, Namespace, fields
from __future__ import annotations

from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from datetime import datetime
from typing import TypeAlias

from .raws import FilesContainedField
from pydantic import BaseModel, ConfigDict, Field, field_validator

feedback_fields = {
    "rating": fields.String,
}
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile

JSONValueType: TypeAlias = JSONValue


def build_feedback_model(api_or_ns: Api | Namespace):
    """Build the feedback model for the API or Namespace."""
    return api_or_ns.model("Feedback", feedback_fields)
class ResponseModel(BaseModel):
    model_config = ConfigDict(from_attributes=True, extra="ignore")


agent_thought_fields = {
    "id": fields.String,
    "chain_id": fields.String,
    "message_id": fields.String,
    "position": fields.Integer,
    "thought": fields.String,
    "tool": fields.String,
    "tool_labels": fields.Raw,
    "tool_input": fields.String,
    "created_at": TimestampField,
    "observation": fields.String,
    "files": fields.List(fields.String),
}
class SimpleFeedback(ResponseModel):
    rating: str | None = None


def build_agent_thought_model(api_or_ns: Api | Namespace):
    """Build the agent thought model for the API or Namespace."""
    return api_or_ns.model("AgentThought", agent_thought_fields)
class RetrieverResource(ResponseModel):
    id: str
    message_id: str
    position: int
    dataset_id: str | None = None
    dataset_name: str | None = None
    document_id: str | None = None
    document_name: str | None = None
    data_source_type: str | None = None
    segment_id: str | None = None
    score: float | None = None
    hit_count: int | None = None
    word_count: int | None = None
    segment_position: int | None = None
    index_node_hash: str | None = None
    content: str | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value


retriever_resource_fields = {
    "id": fields.String,
    "message_id": fields.String,
    "position": fields.Integer,
    "dataset_id": fields.String,
    "dataset_name": fields.String,
    "document_id": fields.String,
    "document_name": fields.String,
    "data_source_type": fields.String,
    "segment_id": fields.String,
    "score": fields.Float,
    "hit_count": fields.Integer,
    "word_count": fields.Integer,
    "segment_position": fields.Integer,
    "index_node_hash": fields.String,
    "content": fields.String,
    "created_at": TimestampField,
}
class MessageListItem(ResponseModel):
    id: str
    conversation_id: str
    parent_message_id: str | None = None
    inputs: dict[str, JSONValueType]
    query: str
    answer: str = Field(validation_alias="re_sign_file_url_answer")
    feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
    retriever_resources: list[RetrieverResource]
    created_at: int | None = None
    agent_thoughts: list[AgentThought]
    message_files: list[MessageFile]
    status: str
    error: str | None = None

message_fields = {
    "id": fields.String,
    "conversation_id": fields.String,
    "parent_message_id": fields.String,
    "inputs": FilesContainedField,
    "query": fields.String,
    "answer": fields.String(attribute="re_sign_file_url_answer"),
    "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
    "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
    "created_at": TimestampField,
    "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
    "message_files": fields.List(fields.Nested(message_file_fields)),
    "status": fields.String,
    "error": fields.String,
}
    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
        return format_files_contained(value)

message_infinite_scroll_pagination_fields = {
    "limit": fields.Integer,
    "has_more": fields.Boolean,
    "data": fields.List(fields.Nested(message_fields)),
}
    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value


class WebMessageListItem(MessageListItem):
    metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")


class MessageInfiniteScrollPagination(ResponseModel):
    limit: int
    has_more: bool
    data: list[MessageListItem]


class WebMessageInfiniteScrollPagination(ResponseModel):
    limit: int
    has_more: bool
    data: list[WebMessageListItem]


class SavedMessageItem(ResponseModel):
    id: str
    inputs: dict[str, JSONValueType]
    query: str
    answer: str
    message_files: list[MessageFile]
    feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
    created_at: int | None = None

    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
        return format_files_contained(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value


class SavedMessageInfiniteScrollPagination(ResponseModel):
    limit: int
    has_more: bool
    data: list[SavedMessageItem]


class SuggestedQuestionsResponse(ResponseModel):
    data: list[str]


def to_timestamp(value: datetime | None) -> int | None:
    if value is None:
        return None
    return int(value.timestamp())


def format_files_contained(value: JSONValueType) -> JSONValueType:
    if isinstance(value, File):
        return value.model_dump()
    if isinstance(value, dict):
        return {k: format_files_contained(v) for k, v in value.items()}
    if isinstance(value, list):
        return [format_files_contained(v) for v in value]
    return value
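Reviewer note (not part of the diff): `ResponseModel` sets `from_attributes=True`, so these classes can be validated straight from ORM rows, and the `mode="before"` validators coerce `datetime` columns to Unix timestamps. A rough, self-contained sketch of that pattern (the model and row here are illustrative, not the real ORM):

```python
from datetime import datetime, timezone

from pydantic import BaseModel, ConfigDict, field_validator


class ExampleResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True, extra="ignore")

    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        # ORM rows hand us datetimes; the API contract wants epoch seconds.
        return int(value.timestamp()) if isinstance(value, datetime) else value


class FakeRow:  # stands in for a SQLAlchemy model instance
    created_at = datetime(2024, 1, 1, tzinfo=timezone.utc)


print(ExampleResponse.model_validate(FakeRow()).created_at)  # 1704067200
```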
@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

dataset_tag_fields = {
    "id": fields.String,
@@ -8,5 +8,5 @@ dataset_tag_fields = {
}


def build_dataset_tag_fields(api_or_ns: Api | Namespace):
def build_dataset_tag_fields(api_or_ns: Namespace):
    return api_or_ns.model("DataSetTag", dataset_tag_fields)

@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
@@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
}


def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
    """Build the workflow app log partial model for the API or Namespace."""
    workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
    simple_account_model = build_simple_account_model(api_or_ns)
@@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
}


def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
    """Build the workflow app log pagination model for the API or Namespace."""
    # Build the nested partial model first
    workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

@@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
}


def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
def build_workflow_run_for_log_model(api_or_ns: Namespace):
    return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
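Reviewer note (not part of the diff): all of the `build_*_model` helpers now accept only a `Namespace`, which is what the RESTX controllers actually pass. As a rough illustration (names here are made up), registering one of these field dicts on a namespace looks like this:

```python
from flask_restx import Namespace, fields

ns = Namespace("example", description="Illustrative namespace only")

simple_fields = {
    "id": fields.String,
}


def build_simple_model(api_or_ns: Namespace):
    # Same shape as the builders above: register the dict as a named RESTX model.
    return api_or_ns.model("SimpleExample", simple_fields)


simple_model = build_simple_model(ns)
```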
api/libs/archive_storage.py (new file, 347 lines)
@@ -0,0 +1,347 @@
"""
|
||||
Archive Storage Client for S3-compatible storage.
|
||||
|
||||
This module provides a dedicated storage client for archiving or exporting logs
|
||||
to S3-compatible object storage.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import gzip
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
import boto3
|
||||
import orjson
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArchiveStorageError(Exception):
|
||||
"""Base exception for archive storage operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
|
||||
"""Raised when archive storage is not properly configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArchiveStorage:
|
||||
"""
|
||||
S3-compatible storage client for archiving or exporting.
|
||||
|
||||
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket: str):
|
||||
if not dify_config.ARCHIVE_STORAGE_ENABLED:
|
||||
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
|
||||
|
||||
if not bucket:
|
||||
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
|
||||
if not all(
|
||||
[
|
||||
dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||
bucket,
|
||||
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||
]
|
||||
):
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive storage configuration is incomplete. "
|
||||
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
|
||||
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
|
||||
)
|
||||
|
||||
self.bucket = bucket
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||
region_name=dify_config.ARCHIVE_STORAGE_REGION,
|
||||
config=Config(s3={"addressing_style": "path"}),
|
||||
)
|
||||
|
||||
# Verify bucket accessibility
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.bucket)
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "404":
|
||||
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
|
||||
elif error_code == "403":
|
||||
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
|
||||
else:
|
||||
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
|
||||
|
||||
def put_object(self, key: str, data: bytes) -> str:
|
||||
"""
|
||||
Upload an object to the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
data: Binary data to upload
|
||||
|
||||
Returns:
|
||||
MD5 checksum of the uploaded data
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If upload fails
|
||||
"""
|
||||
checksum = hashlib.md5(data).hexdigest()
|
||||
try:
|
||||
self.client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Key=key,
|
||||
Body=data,
|
||||
ContentMD5=self._content_md5(data),
|
||||
)
|
||||
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
|
||||
return checksum
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
|
||||
|
||||
def get_object(self, key: str) -> bytes:
|
||||
"""
|
||||
Download an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Returns:
|
||||
Binary data of the object
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If download fails
|
||||
FileNotFoundError: If object does not exist
|
||||
"""
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||
return response["Body"].read()
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "NoSuchKey":
|
||||
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
|
||||
|
||||
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Stream an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Yields:
|
||||
Chunks of binary data
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If download fails
|
||||
FileNotFoundError: If object does not exist
|
||||
"""
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||
yield from response["Body"].iter_chunks()
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "NoSuchKey":
|
||||
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
|
||||
|
||||
def object_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if an object exists in the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Returns:
|
||||
True if object exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket, Key=key)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
||||
def delete_object(self, key: str) -> None:
|
||||
"""
|
||||
Delete an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If deletion fails
|
||||
"""
|
||||
try:
|
||||
self.client.delete_object(Bucket=self.bucket, Key=key)
|
||||
logger.debug("Deleted object: %s", key)
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
|
||||
|
||||
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Generate a pre-signed URL for downloading an object.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
|
||||
Returns:
|
||||
Pre-signed URL string.
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If generation fails
|
||||
"""
|
||||
try:
|
||||
return self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket, "Key": key},
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
|
||||
|
||||
def list_objects(self, prefix: str) -> list[str]:
|
||||
"""
|
||||
List objects under a given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Object key prefix to filter by
|
||||
|
||||
Returns:
|
||||
List of object keys matching the prefix
|
||||
"""
|
||||
keys = []
|
||||
paginator = self.client.get_paginator("list_objects_v2")
|
||||
|
||||
try:
|
||||
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
|
||||
for obj in page.get("Contents", []):
|
||||
keys.append(obj["Key"])
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
|
||||
|
||||
return keys
|
||||
|
||||
@staticmethod
|
||||
def _content_md5(data: bytes) -> str:
|
||||
"""Calculate base64-encoded MD5 for Content-MD5 header."""
|
||||
return base64.b64encode(hashlib.md5(data).digest()).decode()
|
||||
|
||||
@staticmethod
|
||||
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
|
||||
"""
|
||||
Serialize records to gzipped JSONL format.
|
||||
|
||||
Args:
|
||||
records: List of dictionaries to serialize
|
||||
|
||||
Returns:
|
||||
Gzipped JSONL bytes
|
||||
"""
|
||||
lines = []
|
||||
for record in records:
|
||||
# Convert datetime objects to ISO format strings
|
||||
serialized = ArchiveStorage._serialize_record(record)
|
||||
lines.append(orjson.dumps(serialized))
|
||||
|
||||
jsonl_content = b"\n".join(lines)
|
||||
if jsonl_content:
|
||||
jsonl_content += b"\n"
|
||||
|
||||
return gzip.compress(jsonl_content)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Deserialize gzipped JSONL data to records.
|
||||
|
||||
Args:
|
||||
data: Gzipped JSONL bytes
|
||||
|
||||
Returns:
|
||||
List of dictionaries
|
||||
"""
|
||||
jsonl_content = gzip.decompress(data)
|
||||
records = []
|
||||
|
||||
for line in jsonl_content.splitlines():
|
||||
if line:
|
||||
records.append(orjson.loads(line))
|
||||
|
||||
return records
|
||||
|
||||
@staticmethod
|
||||
def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize a single record, converting special types."""
|
||||
|
||||
def _serialize(item: Any) -> Any:
|
||||
if isinstance(item, datetime.datetime):
|
||||
return item.isoformat()
|
||||
if isinstance(item, dict):
|
||||
return {key: _serialize(value) for key, value in item.items()}
|
||||
if isinstance(item, list):
|
||||
return [_serialize(value) for value in item]
|
||||
return item
|
||||
|
||||
return cast(dict[str, Any], _serialize(record))
|
||||
|
||||
@staticmethod
|
||||
def compute_checksum(data: bytes) -> str:
|
||||
"""Compute MD5 checksum of data."""
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
# Singleton instance (lazy initialization)
|
||||
_archive_storage: ArchiveStorage | None = None
|
||||
_export_storage: ArchiveStorage | None = None
|
||||
|
||||
|
||||
def get_archive_storage() -> ArchiveStorage:
|
||||
"""
|
||||
Get the archive storage singleton instance.
|
||||
|
||||
Returns:
|
||||
ArchiveStorage instance
|
||||
|
||||
Raises:
|
||||
ArchiveStorageNotConfiguredError: If archive storage is not configured
|
||||
"""
|
||||
global _archive_storage
|
||||
if _archive_storage is None:
|
||||
archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
|
||||
if not archive_bucket:
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
|
||||
)
|
||||
_archive_storage = ArchiveStorage(bucket=archive_bucket)
|
||||
return _archive_storage
|
||||
|
||||
|
||||
def get_export_storage() -> ArchiveStorage:
|
||||
"""
|
||||
Get the export storage singleton instance.
|
||||
|
||||
Returns:
|
||||
ArchiveStorage instance
|
||||
"""
|
||||
global _export_storage
|
||||
if _export_storage is None:
|
||||
export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
|
||||
if not export_bucket:
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
|
||||
)
|
||||
_export_storage = ArchiveStorage(bucket=export_bucket)
|
||||
return _export_storage
|
||||
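Reviewer note (not part of the diff): a rough end-to-end sketch of how this client is meant to be used, assuming the `ARCHIVE_STORAGE_*` settings are populated and the bucket exists; the object key is illustrative only:

```python
from libs.archive_storage import get_archive_storage

storage = get_archive_storage()  # raises ArchiveStorageNotConfiguredError if disabled or misconfigured

records = [{"id": "msg-1", "answer": "hello"}]
payload = storage.serialize_to_jsonl_gz(records)  # gzipped JSONL bytes
checksum = storage.put_object("archives/2025/01/messages.jsonl.gz", payload)

restored = storage.deserialize_from_jsonl_gz(
    storage.get_object("archives/2025/01/messages.jsonl.gz")
)
assert restored == records
assert checksum == storage.compute_checksum(payload)
```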
@@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any

@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
    data.setdefault("code", "unknown")
    data.setdefault("status", status_code)

    # Log stack
    exc_info: Any = sys.exc_info()
    if exc_info[1] is None:
        exc_info = (None, None, None)
    current_app.log_exception(exc_info)
    # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
    # Explicit log_exception call removed to avoid duplicate log entries

    return data, status_code
@@ -11,9 +11,6 @@ from alembic import op
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@@ -23,31 +20,17 @@ depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
            batch_op.drop_column('description_str')
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
            batch_op.drop_column('description_str')
    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
        batch_op.drop_column('description_str')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
            batch_op.drop_column('description')
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
            batch_op.drop_column('description')
    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
        batch_op.drop_column('description')

    # ### end Alembic commands ###
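Reviewer note (not part of the diff): the migrations below all apply the same simplification, dropping the per-dialect `_is_pg(conn)` branching in favour of the `models.types` column types, which appear to resolve to the right concrete type on both PostgreSQL and MySQL. A sketch of the shape being removed (only meaningful inside an Alembic migration context):

```python
# Sketch only: the dialect branch these migrations delete, followed by the kept form.
import sqlalchemy as sa
from alembic import op

import models.types


def upgrade():
    conn = op.get_bind()
    if conn.dialect.name == "postgresql":  # removed: PostgreSQL-specific branch
        description_type = sa.Text()
    else:  # removed: MySQL branch
        description_type = models.types.LongText()
    # kept: a single portable column type, no dialect check needed
    with op.batch_alter_table("tool_api_providers", schema=None) as batch_op:
        batch_op.add_column(sa.Column("description", models.types.LongText(), nullable=False))
```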
@@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@@ -32,13 +28,7 @@ def upgrade():


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
    else:
        with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
    with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
        batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))

    # ### end Alembic commands ###

@@ -11,9 +11,6 @@ from alembic import op
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@@ -23,16 +20,9 @@ depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
            # Step 1: Add column without NOT NULL constraint
            op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
    else:
        with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
            # Step 1: Add column without NOT NULL constraint
            op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
    with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
        # Step 1: Add column without NOT NULL constraint
        op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))

    # ### end Alembic commands ###
@@ -9,11 +9,6 @@ from alembic import op
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@@ -23,58 +18,30 @@ depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('document_id',
                existing_type=sa.UUID(),
                nullable=True)
            batch_op.alter_column('data_source_type',
                existing_type=sa.TEXT(),
                nullable=True)
            batch_op.alter_column('segment_id',
                existing_type=sa.UUID(),
                nullable=True)
    else:
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('document_id',
                existing_type=models.types.StringUUID(),
                nullable=True)
            batch_op.alter_column('data_source_type',
                existing_type=models.types.LongText(),
                nullable=True)
            batch_op.alter_column('segment_id',
                existing_type=models.types.StringUUID(),
                nullable=True)
    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
        batch_op.alter_column('document_id',
            existing_type=models.types.StringUUID(),
            nullable=True)
        batch_op.alter_column('data_source_type',
            existing_type=models.types.LongText(),
            nullable=True)
        batch_op.alter_column('segment_id',
            existing_type=models.types.StringUUID(),
            nullable=True)
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('segment_id',
                existing_type=sa.UUID(),
                nullable=False)
            batch_op.alter_column('data_source_type',
                existing_type=sa.TEXT(),
                nullable=False)
            batch_op.alter_column('document_id',
                existing_type=sa.UUID(),
                nullable=False)
    else:
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('segment_id',
                existing_type=models.types.StringUUID(),
                nullable=False)
            batch_op.alter_column('data_source_type',
                existing_type=models.types.LongText(),
                nullable=False)
            batch_op.alter_column('document_id',
                existing_type=models.types.StringUUID(),
                nullable=False)
    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
        batch_op.alter_column('segment_id',
            existing_type=models.types.StringUUID(),
            nullable=False)
        batch_op.alter_column('data_source_type',
            existing_type=models.types.LongText(),
            nullable=False)
        batch_op.alter_column('document_id',
            existing_type=models.types.StringUUID(),
            nullable=False)

    # ### end Alembic commands ###
@@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'

@@ -28,85 +28,45 @@ def upgrade():
    op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
    op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")

    if _is_pg(conn):
        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)
    with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=sa.VARCHAR(length=255),
            type_=models.types.LongText(),
            nullable=False)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)
    with op.batch_alter_table('sites', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=sa.VARCHAR(length=255),
            type_=models.types.LongText(),
            nullable=False)

        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)
    else:
        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)

        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)
    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=sa.VARCHAR(length=255),
            type_=models.types.LongText(),
            nullable=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)
    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=models.types.LongText(),
            type_=sa.VARCHAR(length=255),
            nullable=True)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)
    with op.batch_alter_table('sites', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=models.types.LongText(),
            type_=sa.VARCHAR(length=255),
            nullable=True)

        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)
    with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer',
            existing_type=models.types.LongText(),
            type_=sa.VARCHAR(length=255),
            nullable=True)

    # ### end Alembic commands ###
@@ -49,57 +49,33 @@ def upgrade():
    op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
    op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
    op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('graph',
                existing_type=sa.TEXT(),
                nullable=False)
            batch_op.alter_column('features',
                existing_type=sa.TEXT(),
                nullable=False)
            batch_op.alter_column('updated_at',
                existing_type=postgresql.TIMESTAMP(),
                nullable=False)
    else:
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('graph',
                existing_type=models.types.LongText(),
                nullable=False)
            batch_op.alter_column('features',
                existing_type=models.types.LongText(),
                nullable=False)
            batch_op.alter_column('updated_at',
                existing_type=sa.TIMESTAMP(),
                nullable=False)

    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.alter_column('graph',
            existing_type=models.types.LongText(),
            nullable=False)
        batch_op.alter_column('features',
            existing_type=models.types.LongText(),
            nullable=False)
        batch_op.alter_column('updated_at',
            existing_type=sa.TIMESTAMP(),
            nullable=False)
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('updated_at',
                existing_type=postgresql.TIMESTAMP(),
                nullable=True)
            batch_op.alter_column('features',
                existing_type=sa.TEXT(),
                nullable=True)
            batch_op.alter_column('graph',
                existing_type=sa.TEXT(),
                nullable=True)
    else:
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('updated_at',
                existing_type=sa.TIMESTAMP(),
                nullable=True)
            batch_op.alter_column('features',
                existing_type=models.types.LongText(),
                nullable=True)
            batch_op.alter_column('graph',
                existing_type=models.types.LongText(),
                nullable=True)
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.alter_column('updated_at',
            existing_type=sa.TIMESTAMP(),
            nullable=True)
        batch_op.alter_column('features',
            existing_type=models.types.LongText(),
            nullable=True)
        batch_op.alter_column('graph',
            existing_type=models.types.LongText(),
            nullable=True)

    if _is_pg(conn):
        with op.batch_alter_table('messages', schema=None) as batch_op:
@@ -86,57 +86,30 @@ def upgrade():

def migrate_existing_provider_models_data():
    """migrate provider_models table data to provider_model_credentials"""
    conn = op.get_bind()
    # Define table structure for data manipulation
    if _is_pg(conn):
        provider_models_table = table('provider_models',
            column('id', models.types.StringUUID()),
            column('tenant_id', models.types.StringUUID()),
            column('provider_name', sa.String()),
            column('model_name', sa.String()),
            column('model_type', sa.String()),
            column('encrypted_config', sa.Text()),
            column('created_at', sa.DateTime()),
            column('updated_at', sa.DateTime()),
            column('credential_id', models.types.StringUUID()),
        )
    else:
        provider_models_table = table('provider_models',
            column('id', models.types.StringUUID()),
            column('tenant_id', models.types.StringUUID()),
            column('provider_name', sa.String()),
            column('model_name', sa.String()),
            column('model_type', sa.String()),
            column('encrypted_config', models.types.LongText()),
            column('created_at', sa.DateTime()),
            column('updated_at', sa.DateTime()),
            column('credential_id', models.types.StringUUID()),
        )
    # Define table structure for data manipulation
    provider_models_table = table('provider_models',
        column('id', models.types.StringUUID()),
        column('tenant_id', models.types.StringUUID()),
        column('provider_name', sa.String()),
        column('model_name', sa.String()),
        column('model_type', sa.String()),
        column('encrypted_config', models.types.LongText()),
        column('created_at', sa.DateTime()),
        column('updated_at', sa.DateTime()),
        column('credential_id', models.types.StringUUID()),
    )

    if _is_pg(conn):
        provider_model_credentials_table = table('provider_model_credentials',
            column('id', models.types.StringUUID()),
            column('tenant_id', models.types.StringUUID()),
            column('provider_name', sa.String()),
            column('model_name', sa.String()),
            column('model_type', sa.String()),
            column('credential_name', sa.String()),
            column('encrypted_config', sa.Text()),
            column('created_at', sa.DateTime()),
            column('updated_at', sa.DateTime())
        )
    else:
        provider_model_credentials_table = table('provider_model_credentials',
            column('id', models.types.StringUUID()),
            column('tenant_id', models.types.StringUUID()),
            column('provider_name', sa.String()),
            column('model_name', sa.String()),
            column('model_type', sa.String()),
            column('credential_name', sa.String()),
            column('encrypted_config', models.types.LongText()),
            column('created_at', sa.DateTime()),
            column('updated_at', sa.DateTime())
        )
    provider_model_credentials_table = table('provider_model_credentials',
        column('id', models.types.StringUUID()),
        column('tenant_id', models.types.StringUUID()),
        column('provider_name', sa.String()),
        column('model_name', sa.String()),
        column('model_type', sa.String()),
        column('credential_name', sa.String()),
        column('encrypted_config', models.types.LongText()),
        column('created_at', sa.DateTime()),
        column('updated_at', sa.DateTime())
    )


    # Get database connection
@@ -183,14 +156,8 @@ def migrate_existing_provider_models_data():

def downgrade():
    # Re-add encrypted_config column to provider_models table
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('provider_models', schema=None) as batch_op:
            batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
    else:
        with op.batch_alter_table('provider_models', schema=None) as batch_op:
            batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
    with op.batch_alter_table('provider_models', schema=None) as batch_op:
        batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))

    if not context.is_offline_mode():
        # Migrate data back from provider_model_credentials to provider_models
@@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7


def _is_pg(conn):

@@ -9,8 +9,6 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"
import sqlalchemy as sa


@@ -23,12 +21,7 @@ depends_on = None

def upgrade():
    # Add encrypted_headers column to tool_mcp_providers table
    conn = op.get_bind()

    if _is_pg(conn):
        op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
    else:
        op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
    op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))


def downgrade():
@@ -44,6 +44,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
        sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
    )

    if _is_pg(conn):
        op.create_table('datasource_oauth_tenant_params',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -70,6 +71,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
        sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
    )

    if _is_pg(conn):
        op.create_table('datasource_providers',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -104,6 +106,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
    )

    with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
        batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)

@@ -133,6 +136,7 @@ def upgrade():
        sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
    )

    with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
        batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)

@@ -174,6 +178,7 @@ def upgrade():
        sa.Column('updated_by', models.types.StringUUID(), nullable=True),
        sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
    )

    if _is_pg(conn):
        op.create_table('pipeline_customized_templates',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -193,7 +198,6 @@ def upgrade():
            sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
        )
    else:
        # MySQL: Use compatible syntax
        op.create_table('pipeline_customized_templates',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@@ -211,6 +215,7 @@ def upgrade():
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
        )

    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
        batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)

@@ -236,6 +241,7 @@ def upgrade():
        sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
    )

    if _is_pg(conn):
        op.create_table('pipelines',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -266,6 +272,7 @@ def upgrade():
        sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
    )

    if _is_pg(conn):
        op.create_table('workflow_draft_variable_files',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -292,6 +299,7 @@ def upgrade():
        sa.Column('value_type', sa.String(20), nullable=False),
        sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
    )

    if _is_pg(conn):
        op.create_table('workflow_node_execution_offload',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -316,6 +324,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
        sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
    )

    if _is_pg(conn):
        with op.batch_alter_table('datasets', schema=None) as batch_op:
            batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
@@ -342,6 +351,7 @@ def upgrade():
            comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
        )
        batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)

    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
@@ -9,8 +9,6 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"
import sqlalchemy as sa


@@ -33,15 +31,9 @@ def upgrade():

def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
            batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
            batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
    else:
        with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
            batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
            batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))

    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
        batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))

    # ### end Alembic commands ###

@@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7

def _is_pg(conn):
    return conn.dialect.name == "postgresql"

@@ -105,6 +105,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
        sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
    )

    if _is_pg(conn):
        op.create_table('trigger_subscriptions',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
@@ -143,6 +144,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
    )

    with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
        batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
        batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
@@ -176,6 +178,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
        sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
    )

    with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
        batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)

@@ -207,6 +210,7 @@ def upgrade():
        sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
        sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
    )

    with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
        batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)

@@ -264,6 +268,7 @@ def upgrade():
        sa.Column('finished_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
    )

    with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
        batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
        batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
@@ -299,6 +304,7 @@ def upgrade():
        sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
        sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
    )

    with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
        batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
@@ -11,9 +11,6 @@ from alembic import op
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '23db93619b9d'
down_revision = '8ae9bc661daa'
@@ -23,14 +20,8 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
    else:
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
        batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))

    # ### end Alembic commands ###

@@ -62,14 +62,8 @@ def upgrade():

def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
            batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
    else:
        with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
            batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))

    with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
        batch_op.drop_index('app_annotation_settings_app_idx')

@@ -11,9 +11,6 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '2a3aebbbf4bb'
down_revision = 'c031d46af369'
@@ -23,14 +20,8 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('apps', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
    else:
        with op.batch_alter_table('apps', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
    with op.batch_alter_table('apps', schema=None) as batch_op:
        batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))

    # ### end Alembic commands ###

@@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
@@ -24,35 +20,19 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('api_tokens', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
            batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
            batch_op.drop_column('dataset_id')
    else:
        with op.batch_alter_table('api_tokens', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
            batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
            batch_op.drop_column('dataset_id')
    with op.batch_alter_table('api_tokens', schema=None) as batch_op:
        batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
        batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
        batch_op.drop_column('dataset_id')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('api_tokens', schema=None) as batch_op:
            batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
            batch_op.drop_index('api_token_tenant_idx')
            batch_op.drop_column('tenant_id')
    else:
        with op.batch_alter_table('api_tokens', schema=None) as batch_op:
            batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
            batch_op.drop_index('api_token_tenant_idx')
            batch_op.drop_column('tenant_id')
    with op.batch_alter_table('api_tokens', schema=None) as batch_op:
        batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
        batch_op.drop_index('api_token_tenant_idx')
        batch_op.drop_column('tenant_id')

    # ### end Alembic commands ###

@@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
@@ -24,59 +20,31 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('app_model_config_id',
                    existing_type=postgresql.UUID(),
                    nullable=True)
            batch_op.alter_column('model_provider',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=True)
            batch_op.alter_column('model_id',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=True)
    else:
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('app_model_config_id',
                    existing_type=models.types.StringUUID(),
                    nullable=True)
            batch_op.alter_column('model_provider',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=True)
            batch_op.alter_column('model_id',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=True)
    with op.batch_alter_table('conversations', schema=None) as batch_op:
        batch_op.alter_column('app_model_config_id',
                existing_type=models.types.StringUUID(),
                nullable=True)
        batch_op.alter_column('model_provider',
                existing_type=sa.VARCHAR(length=255),
                nullable=True)
        batch_op.alter_column('model_id',
                existing_type=sa.VARCHAR(length=255),
                nullable=True)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('model_id',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=False)
            batch_op.alter_column('model_provider',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=False)
            batch_op.alter_column('app_model_config_id',
                    existing_type=postgresql.UUID(),
                    nullable=False)
    else:
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('model_id',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=False)
            batch_op.alter_column('model_provider',
                    existing_type=sa.VARCHAR(length=255),
                    nullable=False)
            batch_op.alter_column('app_model_config_id',
                    existing_type=models.types.StringUUID(),
                    nullable=False)
    with op.batch_alter_table('conversations', schema=None) as batch_op:
        batch_op.alter_column('model_id',
                existing_type=sa.VARCHAR(length=255),
                nullable=False)
        batch_op.alter_column('model_provider',
                existing_type=sa.VARCHAR(length=255),
                nullable=False)
        batch_op.alter_column('app_model_config_id',
                existing_type=models.types.StringUUID(),
                nullable=False)

    # ### end Alembic commands ###

@@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415

"""
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '4829e54d2fee'
down_revision = '114eed84c228'
@@ -23,39 +19,21 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        # PostgreSQL: Keep original syntax
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.alter_column('message_chain_id',
                    existing_type=postgresql.UUID(),
                    nullable=True)
    else:
        # MySQL: Use compatible syntax
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.alter_column('message_chain_id',
                    existing_type=models.types.StringUUID(),
                    nullable=True)

    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
        batch_op.alter_column('message_chain_id',
                existing_type=models.types.StringUUID(),
                nullable=True)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        # PostgreSQL: Keep original syntax
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.alter_column('message_chain_id',
                    existing_type=postgresql.UUID(),
                    nullable=False)
    else:
        # MySQL: Use compatible syntax
        with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
            batch_op.alter_column('message_chain_id',
                    existing_type=models.types.StringUUID(),
                    nullable=False)

    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
        batch_op.alter_column('message_chain_id',
                existing_type=models.types.StringUUID(),
                nullable=False)

    # ### end Alembic commands ###

@@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506

"""
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '563cf8bf777b'
down_revision = 'b5429b71023c'
@@ -23,35 +19,19 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_files', schema=None) as batch_op:
            batch_op.alter_column('conversation_id',
                    existing_type=postgresql.UUID(),
                    nullable=True)
    else:
        with op.batch_alter_table('tool_files', schema=None) as batch_op:
            batch_op.alter_column('conversation_id',
                    existing_type=models.types.StringUUID(),
                    nullable=True)
    with op.batch_alter_table('tool_files', schema=None) as batch_op:
        batch_op.alter_column('conversation_id',
                existing_type=models.types.StringUUID(),
                nullable=True)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_files', schema=None) as batch_op:
            batch_op.alter_column('conversation_id',
                    existing_type=postgresql.UUID(),
                    nullable=False)
    else:
        with op.batch_alter_table('tool_files', schema=None) as batch_op:
            batch_op.alter_column('conversation_id',
                    existing_type=models.types.StringUUID(),
                    nullable=False)
    with op.batch_alter_table('tool_files', schema=None) as batch_op:
        batch_op.alter_column('conversation_id',
                existing_type=models.types.StringUUID(),
                nullable=False)

    # ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff.