Merge branch 'feat/agent-node-v2' into feat/support-agent-sandbox

commit 1557f48740

@@ -4,6 +4,6 @@
   "context7@claude-plugins-official": true,
   "typescript-lsp@claude-plugins-official": true,
   "pyright-lsp@claude-plugins-official": true,
-  "ralph-wiggum@claude-plugins-official": true
+  "ralph-loop@claude-plugins-official": true
   }
 }

355 .claude/skills/skill-creator/SKILL.md Normal file
@@ -0,0 +1,355 @@
---
name: skill-creator
description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations.
---

# Skill Creator

This skill provides guidance for creating effective skills.

## About Skills

Skills are modular, self-contained packages that extend Claude's capabilities by providing
specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific
domains or tasks—they transform Claude from a general-purpose agent into a specialized agent
equipped with procedural knowledge that no model can fully possess.

### What Skills Provide

1. Specialized workflows - Multi-step procedures for specific domains
2. Tool integrations - Instructions for working with specific file formats or APIs
3. Domain expertise - Company-specific knowledge, schemas, business logic
4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks

## Core Principles

### Concise is Key

The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request.

**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?"

Prefer concise examples over verbose explanations.

### Set Appropriate Degrees of Freedom

Match the level of specificity to the task's fragility and variability:

**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach.

**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior.

**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed.

Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom).
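
As a concrete sketch, here is the same task written at the two extremes (the script name is hypothetical):

```markdown
<!-- High freedom: many valid routes -->
Choose a chart type that fits the data and summarize the main trends.

<!-- Low freedom: fragile operation, one safe route -->
Run `scripts/make_chart.py --type line` and use its output as-is; do not
regenerate the chart by hand.
```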

### Anatomy of a Skill

Every skill consists of a required SKILL.md file and optional bundled resources:

```
skill-name/
├── SKILL.md (required)
│   ├── YAML frontmatter metadata (required)
│   │   ├── name: (required)
│   │   └── description: (required)
│   └── Markdown instructions (required)
└── Bundled Resources (optional)
    ├── scripts/ - Executable code (Python/Bash/etc.)
    ├── references/ - Documentation intended to be loaded into context as needed
    └── assets/ - Files used in output (templates, icons, fonts, etc.)
```

#### SKILL.md (required)

Every SKILL.md consists of:

- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields Claude reads to determine when the skill gets used, so it is very important to be clear and comprehensive in describing what the skill is and when it should be used.
- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all).

#### Bundled Resources (optional)

##### Scripts (`scripts/`)

Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten.

- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed
- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks
- **Benefits**: Token efficient, deterministic, may be executed without loading into context
- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments

##### References (`references/`)

Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking.

- **When to include**: For documentation that Claude should reference while working
- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for a company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications
- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides
- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed
- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md (see the example after this list)
- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files.
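
For example, a grep hint for a large reference file might look like this in SKILL.md (file and table names are hypothetical):

```markdown
## Finding schema details

`references/schema.md` is large; search it instead of reading it whole:

- List table definitions: `grep -n "^### " references/schema.md`
- Find a specific table: `grep -n "user_logins" references/schema.md`
```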

##### Assets (`assets/`)

Files not intended to be loaded into context, but rather used within the output Claude produces.

- **When to include**: When the skill needs files that will be used in the final output
- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography
- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified
- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context

#### What Not to Include in a Skill

A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including:

- README.md
- INSTALLATION_GUIDE.md
- QUICK_REFERENCE.md
- CHANGELOG.md
- etc.

The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxiliary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion.

### Progressive Disclosure Design Principle

Skills use a three-level loading system to manage context efficiently:

1. **Metadata (name + description)** - Always in context (~100 words)
2. **SKILL.md body** - When skill triggers (<5k words)
3. **Bundled resources** - As needed by Claude (unlimited, because scripts can be executed without being read into the context window)

#### Progressive Disclosure Patterns

Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them.

**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files.

**Pattern 1: High-level guide with references**

```markdown
# PDF Processing

## Quick start

Extract text with pdfplumber:
[code example]

## Advanced features

- **Form filling**: See [FORMS.md](FORMS.md) for complete guide
- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods
- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns
```

Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed.

**Pattern 2: Domain-specific organization**

For Skills with multiple domains, organize content by domain to avoid loading irrelevant context:

```
bigquery-skill/
├── SKILL.md (overview and navigation)
└── reference/
    ├── finance.md (revenue, billing metrics)
    ├── sales.md (opportunities, pipeline)
    ├── product.md (API usage, features)
    └── marketing.md (campaigns, attribution)
```

When a user asks about sales metrics, Claude only reads sales.md.

Similarly, for skills supporting multiple frameworks or variants, organize by variant:

```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
    ├── aws.md (AWS deployment patterns)
    ├── gcp.md (GCP deployment patterns)
    └── azure.md (Azure deployment patterns)
```

When the user chooses AWS, Claude only reads aws.md.

**Pattern 3: Conditional details**

Show basic content, link to advanced content:

```markdown
# DOCX Processing

## Creating documents

Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md).

## Editing documents

For simple edits, modify the XML directly.

**For tracked changes**: See [REDLINING.md](REDLINING.md)
**For OOXML details**: See [OOXML.md](OOXML.md)
```

Claude reads REDLINING.md or OOXML.md only when the user needs those features.

**Important guidelines:**

- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md.
- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing (see the sketch after this list).
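
For instance, a longer reference file might open with a brief table of contents (headings are hypothetical):

```markdown
# Finance Schema Reference

## Contents
- Revenue tables
- Billing events
- Common joins and pitfalls
```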

## Skill Creation Process

Skill creation involves these steps:

1. Understand the skill with concrete examples
2. Plan reusable skill contents (scripts, references, assets)
3. Initialize the skill (run init_skill.py)
4. Edit the skill (implement resources and write SKILL.md)
5. Package the skill (run package_skill.py)
6. Iterate based on real usage

Follow these steps in order, skipping only if there is a clear reason why they are not applicable.

### Step 1: Understanding the Skill with Concrete Examples

Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.

To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.

For example, when building an image-editor skill, relevant questions include:

- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"

To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed.

Conclude this step when there is a clear sense of the functionality the skill should support.

### Step 2: Planning the Reusable Skill Contents

To turn concrete examples into an effective skill, analyze each example by:

1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly

Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:

1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill

Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:

1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill

Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:

1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill

To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.
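
The resulting plan can be captured as a short list; for the `pdf-editor` example above it might look like this (entries are hypothetical):

```markdown
Reusable resources for pdf-editor:
- scripts/rotate_pdf.py (rotate pages)
- references/form_fields.md (notes on field types)
- assets/cover_template.pdf (sample cover page)
```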

### Step 3: Initializing the Skill

At this point, it is time to actually create the skill.

Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step.

When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.

Usage:

```bash
scripts/init_skill.py <skill-name> --path <output-directory>
```

The script:

- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Creates example resource directories: `scripts/`, `references/`, and `assets/`
- Adds example files in each directory that can be customized or deleted

After initialization, customize or remove the generated SKILL.md and example files as needed.

### Step 4: Edit the Skill

When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively.

#### Learn Proven Design Patterns

Consult these helpful guides based on your skill's needs:

- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns

These files contain established best practices for effective skill design.

#### Start with Reusable Skill Contents

To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.

Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion.

Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them.

#### Update SKILL.md

**Writing Guidelines:** Always use imperative/infinitive form.

##### Frontmatter

Write the YAML frontmatter with `name` and `description`:

- `name`: The skill name
- `description`: This is the primary triggering mechanism for your skill, and helps Claude understand when to use the skill.
  - Include both what the Skill does and specific triggers/contexts for when to use it.
  - Include all "when to use" information here - not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude.
  - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"

Do not include any other fields in YAML frontmatter.
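
Putting these rules together, a complete frontmatter block for a hypothetical `pdf-editor` skill might look like:

```yaml
---
name: pdf-editor
description: Rotate, merge, split, and fill PDF documents. Use when the user asks to modify, combine, or extract content from .pdf files.
---
```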

##### Body

Write instructions for using the skill and its bundled resources.

### Step 5: Packaging a Skill

Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements:

```bash
scripts/package_skill.py <path/to/skill-folder>
```

Optional output directory specification:

```bash
scripts/package_skill.py <path/to/skill-folder> ./dist
```

The packaging script will:

1. **Validate** the skill automatically, checking:

   - YAML frontmatter format and required fields
   - Skill naming conventions and directory structure
   - Description completeness and quality
   - File organization and resource references

2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
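
Because a .skill file is an ordinary zip archive, its contents can be inspected with standard tooling. A minimal sketch, assuming a packaged `my-skill.skill` in the current directory:

```python
# List the files inside a packaged skill to confirm the directory structure survived.
import zipfile

with zipfile.ZipFile("my-skill.skill") as zf:  # hypothetical package name
    for name in zf.namelist():
        print(name)  # e.g. my-skill/SKILL.md, my-skill/scripts/...
```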

If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.

### Step 6: Iterate

After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed.

**Iteration workflow:**

1. Use the skill on real tasks
2. Notice struggles or inefficiencies
3. Identify how SKILL.md or bundled resources should be updated
4. Implement changes and test again

86 .claude/skills/skill-creator/references/output-patterns.md Normal file
@@ -0,0 +1,86 @@
# Output Patterns

Use these patterns when skills need to produce consistent, high-quality output.

## Template Pattern

Provide templates for output format. Match the level of strictness to your needs.

**For strict requirements (like API responses or data formats):**

```markdown
## Report structure

ALWAYS use this exact template structure:

# [Analysis Title]

## Executive summary
[One-paragraph overview of key findings]

## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data

## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```

**For flexible guidance (when adaptation is useful):**

```markdown
## Report structure

Here is a sensible default format, but use your best judgment:

# [Analysis Title]

## Executive summary
[Overview]

## Key findings
[Adapt sections based on what you discover]

## Recommendations
[Tailor to the specific context]

Adjust sections as needed for the specific analysis type.
```

## Examples Pattern

For skills where output quality depends on seeing examples, provide input/output pairs:

```markdown
## Commit message format

Generate commit messages following these examples:

**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```

feat(auth): implement JWT-based authentication

Add login endpoint and token validation middleware

```

**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```

fix(reports): correct date formatting in timezone conversion

Use UTC timestamps consistently across report generation

```

Follow this style: type(scope): brief description, then detailed explanation.
```

Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.

28 .claude/skills/skill-creator/references/workflows.md Normal file
@@ -0,0 +1,28 @@
# Workflow Patterns

## Sequential Workflows

For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md:

```markdown
Filling a PDF form involves these steps:

1. Analyze the form (run analyze_form.py)
2. Create field mapping (edit fields.json)
3. Validate mapping (run validate_fields.py)
4. Fill the form (run fill_form.py)
5. Verify output (run verify_output.py)
```

## Conditional Workflows

For tasks with branching logic, guide Claude through decision points:

```markdown
1. Determine the modification type:
   **Creating new content?** → Follow "Creation workflow" below
   **Editing existing content?** → Follow "Editing workflow" below

2. Creation workflow: [steps]
3. Editing workflow: [steps]
```

300 .claude/skills/skill-creator/scripts/init_skill.py Executable file
@@ -0,0 +1,300 @@
#!/usr/bin/env python3
"""
Skill Initializer - Creates a new skill from template

Usage:
    init_skill.py <skill-name> --path <path>

Examples:
    init_skill.py my-new-skill --path skills/public
    init_skill.py my-api-helper --path skills/private
    init_skill.py custom-skill --path /custom/location
"""

import sys
from pathlib import Path


SKILL_TEMPLATE = """---
name: {skill_name}
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
---

# {skill_title}

## Overview

[TODO: 1-2 sentences explaining what this skill enables]

## Structuring This Skill

[TODO: Choose the structure that best fits this skill's purpose. Common patterns:

**1. Workflow-Based** (best for sequential processes)
- Works well when there are clear step-by-step procedures
- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing"
- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2...

**2. Task-Based** (best for tool collections)
- Works well when the skill offers different operations/capabilities
- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text"
- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2...

**3. Reference/Guidelines** (best for standards or specifications)
- Works well for brand guidelines, coding standards, or requirements
- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features"
- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage...

**4. Capabilities-Based** (best for integrated systems)
- Works well when the skill provides multiple interrelated features
- Example: Product Management with "Core Capabilities" → numbered capability list
- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature...

Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).

Delete this entire "Structuring This Skill" section when done - it's just guidance.]

## [TODO: Replace with the first main section based on chosen structure]

[TODO: Add content here. See examples in existing skills:
- Code samples for technical skills
- Decision trees for complex workflows
- Concrete examples with realistic user requests
- References to scripts/templates/references as needed]

## Resources

This skill includes example resource directories that demonstrate how to organize different types of bundled resources:

### scripts/
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.

**Examples from other skills:**
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing

**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.

**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments.

### references/
Documentation and reference material intended to be loaded into context to inform Claude's process and thinking.

**Examples from other skills:**
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
- BigQuery: API reference documentation and query examples
- Finance: Schema documentation, company policies

**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working.

### assets/
Files not intended to be loaded into context, but rather used within the output Claude produces.

**Examples from other skills:**
- Brand styling: PowerPoint template files (.pptx), logo files
- Frontend builder: HTML/React boilerplate project directories
- Typography: Font files (.ttf, .woff2)

**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.

---

**Any unneeded directories can be deleted.** Not every skill requires all three types of resources.
"""

EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}

This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.

Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""

def main():
    print("This is an example script for {skill_name}")
    # TODO: Add actual script logic here
    # This could be data processing, file conversion, API calls, etc.

if __name__ == "__main__":
    main()
'''

EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}

This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.

Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples

## When Reference Docs Are Useful

Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases

## Structure Suggestions

### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits

### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""

EXAMPLE_ASSET = """# Example Asset File

This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.

Asset files are NOT intended to be loaded into context, but rather used within
the output Claude produces.

Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json

## Common Asset Types

- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml

Note: This is a text placeholder. Actual assets can be any file type.
"""


def title_case_skill_name(skill_name):
    """Convert hyphenated skill name to Title Case for display."""
    return " ".join(word.capitalize() for word in skill_name.split("-"))


def init_skill(skill_name, path):
    """
    Initialize a new skill directory with template SKILL.md.

    Args:
        skill_name: Name of the skill
        path: Path where the skill directory should be created

    Returns:
        Path to created skill directory, or None if error
    """
    # Determine skill directory path
    skill_dir = Path(path).resolve() / skill_name

    # Check if directory already exists
    if skill_dir.exists():
        print(f"❌ Error: Skill directory already exists: {skill_dir}")
        return None

    # Create skill directory
    try:
        skill_dir.mkdir(parents=True, exist_ok=False)
        print(f"✅ Created skill directory: {skill_dir}")
    except Exception as e:
        print(f"❌ Error creating directory: {e}")
        return None

    # Create SKILL.md from template
    skill_title = title_case_skill_name(skill_name)
    skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)

    skill_md_path = skill_dir / "SKILL.md"
    try:
        skill_md_path.write_text(skill_content)
        print("✅ Created SKILL.md")
    except Exception as e:
        print(f"❌ Error creating SKILL.md: {e}")
        return None

    # Create resource directories with example files
    try:
        # Create scripts/ directory with example script
        scripts_dir = skill_dir / "scripts"
        scripts_dir.mkdir(exist_ok=True)
        example_script = scripts_dir / "example.py"
        example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
        example_script.chmod(0o755)
        print("✅ Created scripts/example.py")

        # Create references/ directory with example reference doc
        references_dir = skill_dir / "references"
        references_dir.mkdir(exist_ok=True)
        example_reference = references_dir / "api_reference.md"
        example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
        print("✅ Created references/api_reference.md")

        # Create assets/ directory with example asset placeholder
        assets_dir = skill_dir / "assets"
        assets_dir.mkdir(exist_ok=True)
        example_asset = assets_dir / "example_asset.txt"
        example_asset.write_text(EXAMPLE_ASSET)
        print("✅ Created assets/example_asset.txt")
    except Exception as e:
        print(f"❌ Error creating resource directories: {e}")
        return None

    # Print next steps
    print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}")
    print("\nNext steps:")
    print("1. Edit SKILL.md to complete the TODO items and update the description")
    print("2. Customize or delete the example files in scripts/, references/, and assets/")
    print("3. Run the validator when ready to check the skill structure")

    return skill_dir


def main():
    if len(sys.argv) < 4 or sys.argv[2] != "--path":
        print("Usage: init_skill.py <skill-name> --path <path>")
        print("\nSkill name requirements:")
        print("  - Hyphen-case identifier (e.g., 'data-analyzer')")
        print("  - Lowercase letters, digits, and hyphens only")
        print("  - Max 40 characters")
        print("  - Must match directory name exactly")
        print("\nExamples:")
        print("  init_skill.py my-new-skill --path skills/public")
        print("  init_skill.py my-api-helper --path skills/private")
        print("  init_skill.py custom-skill --path /custom/location")
        sys.exit(1)

    skill_name = sys.argv[1]
    path = sys.argv[3]

    print(f"🚀 Initializing skill: {skill_name}")
    print(f"   Location: {path}")
    print()

    result = init_skill(skill_name, path)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()

110 .claude/skills/skill-creator/scripts/package_skill.py Executable file
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder

Usage:
    python scripts/package_skill.py <path/to/skill-folder> [output-directory]

Example:
    python scripts/package_skill.py skills/public/my-skill
    python scripts/package_skill.py skills/public/my-skill ./dist
"""

import sys
import zipfile
from pathlib import Path
from quick_validate import validate_skill


def package_skill(skill_path, output_dir=None):
    """
    Package a skill folder into a .skill file.

    Args:
        skill_path: Path to the skill folder
        output_dir: Optional output directory for the .skill file (defaults to current directory)

    Returns:
        Path to the created .skill file, or None if error
    """
    skill_path = Path(skill_path).resolve()

    # Validate skill folder exists
    if not skill_path.exists():
        print(f"❌ Error: Skill folder not found: {skill_path}")
        return None

    if not skill_path.is_dir():
        print(f"❌ Error: Path is not a directory: {skill_path}")
        return None

    # Validate SKILL.md exists
    skill_md = skill_path / "SKILL.md"
    if not skill_md.exists():
        print(f"❌ Error: SKILL.md not found in {skill_path}")
        return None

    # Run validation before packaging
    print("🔍 Validating skill...")
    valid, message = validate_skill(skill_path)
    if not valid:
        print(f"❌ Validation failed: {message}")
        print("   Please fix the validation errors before packaging.")
        return None
    print(f"✅ {message}\n")

    # Determine output location
    skill_name = skill_path.name
    if output_dir:
        output_path = Path(output_dir).resolve()
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path.cwd()

    skill_filename = output_path / f"{skill_name}.skill"

    # Create the .skill file (zip format)
    try:
        with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
            # Walk through the skill directory
            for file_path in skill_path.rglob("*"):
                if file_path.is_file():
                    # Calculate the relative path within the zip
                    arcname = file_path.relative_to(skill_path.parent)
                    zipf.write(file_path, arcname)
                    print(f"  Added: {arcname}")

        print(f"\n✅ Successfully packaged skill to: {skill_filename}")
        return skill_filename

    except Exception as e:
        print(f"❌ Error creating .skill file: {e}")
        return None


def main():
    if len(sys.argv) < 2:
        print("Usage: python scripts/package_skill.py <path/to/skill-folder> [output-directory]")
        print("\nExample:")
        print("  python scripts/package_skill.py skills/public/my-skill")
        print("  python scripts/package_skill.py skills/public/my-skill ./dist")
        sys.exit(1)

    skill_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else None

    print(f"📦 Packaging skill: {skill_path}")
    if output_dir:
        print(f"   Output directory: {output_dir}")
    print()

    result = package_skill(skill_path, output_dir)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()

97 .claude/skills/skill-creator/scripts/quick_validate.py Executable file
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""
Quick validation script for skills - minimal version
"""

import sys
import re
import yaml
from pathlib import Path


def validate_skill(skill_path):
    """Basic validation of a skill"""
    skill_path = Path(skill_path)

    # Check SKILL.md exists
    skill_md = skill_path / "SKILL.md"
    if not skill_md.exists():
        return False, "SKILL.md not found"

    # Read and validate frontmatter
    content = skill_md.read_text()
    if not content.startswith("---"):
        return False, "No YAML frontmatter found"

    # Extract frontmatter
    match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
    if not match:
        return False, "Invalid frontmatter format"

    frontmatter_text = match.group(1)

    # Parse YAML frontmatter
    try:
        frontmatter = yaml.safe_load(frontmatter_text)
        if not isinstance(frontmatter, dict):
            return False, "Frontmatter must be a YAML dictionary"
    except yaml.YAMLError as e:
        return False, f"Invalid YAML in frontmatter: {e}"

    # Define allowed properties
    ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"}

    # Check for unexpected properties (excluding nested keys under metadata)
    unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES
    if unexpected_keys:
        return False, (
            f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. "
            f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}"
        )

    # Check required fields
    if "name" not in frontmatter:
        return False, "Missing 'name' in frontmatter"
    if "description" not in frontmatter:
        return False, "Missing 'description' in frontmatter"

    # Extract name for validation
    name = frontmatter.get("name", "")
    if not isinstance(name, str):
        return False, f"Name must be a string, got {type(name).__name__}"
    name = name.strip()
    if name:
        # Check naming convention (hyphen-case: lowercase with hyphens)
        if not re.match(r"^[a-z0-9-]+$", name):
            return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)"
        if name.startswith("-") or name.endswith("-") or "--" in name:
            return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens"
        # Check name length (max 64 characters per spec)
        if len(name) > 64:
            return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters."

    # Extract and validate description
    description = frontmatter.get("description", "")
    if not isinstance(description, str):
        return False, f"Description must be a string, got {type(description).__name__}"
    description = description.strip()
    if description:
        # Check for angle brackets
        if "<" in description or ">" in description:
            return False, "Description cannot contain angle brackets (< or >)"
        # Check description length (max 1024 characters per spec)
        if len(description) > 1024:
            return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters."

    return True, "Skill is valid!"


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python quick_validate.py <skill_directory>")
        sys.exit(1)

    valid, message = validate_skill(sys.argv[1])
    print(message)
    sys.exit(0 if valid else 1)

@@ -53,7 +53,6 @@ ignore_imports =
    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
    core.workflow.nodes.llm.node -> extensions.ext_database
    core.workflow.nodes.tool.tool_node -> extensions.ext_database
    core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
    core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

@@ -50,16 +50,33 @@ WORKDIR /app/api

# Create non-root user
ARG dify_uid=1001
ARG NODE_MAJOR=22
ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
    useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
    chown -R dify:dify /app

RUN \
    apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
    && gpg --show-keys --with-colons /tmp/nodesource.gpg \
        | awk -F: '/^fpr:/ {print $10}' \
        | grep -Fx "${NODESOURCE_KEY_FPR}" \
    && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
    && rm -f /tmp/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
        > /etc/apt/sources.list.d/nodesource.list \
    && apt-get update \
    # Install dependencies
    && apt-get install -y --no-install-recommends \
        # basic environment
        curl nodejs \
        nodejs=${NODE_PACKAGE_VERSION} \
        # for gmpy2 \
        libgmp-dev libmpfr-dev libmpc-dev \
        # For Security

@ -1,14 +1,16 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@ -19,27 +21,19 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
deleted_tool_fields,
|
||||
model_config_fields,
|
||||
model_config_partial_fields,
|
||||
site_fields,
|
||||
tag_fields,
|
||||
)
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
@ -192,124 +186,292 @@ class AppTracePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
tag_model = console_ns.model("Tag", tag_fields)
|
||||
|
||||
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
model_config_model = console_ns.model("ModelConfig", model_config_fields)
|
||||
|
||||
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
|
||||
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
|
||||
if icon is None or icon_type is None:
|
||||
return None
|
||||
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
|
||||
if icon_type_value.lower() != IconType.IMAGE.value:
|
||||
return None
|
||||
return file_helpers.get_signed_file_url(icon)
|
||||
|
||||
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
|
||||
|
||||
site_model = console_ns.model("Site", site_fields)
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
|
||||
app_partial_model = console_ns.model(
|
||||
"AppPartial",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"max_active_requests": fields.Raw(),
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_model = console_ns.model(
|
||||
"AppDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
},
|
||||
)
|
||||
class WorkflowPartial(ResponseModel):
|
||||
id: str
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
app_detail_with_site_model = console_ns.model(
|
||||
"AppDetailWithSite",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"site": fields.Nested(site_model),
|
||||
},
|
||||
)
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
|
||||
app_pagination_model = console_ns.model(
|
||||
"AppPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
|
||||
},
|
||||
|
||||
class ModelConfigPartial(ResponseModel):
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
pre_prompt: str | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
)
suggested_questions_after_answer: JSONValue | None = Field(
default=None,
validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
)
speech_to_text: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
)
text_to_speech: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
)
retriever_resource: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
)
annotation_reply: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
)
more_like_this: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
)
sensitive_word_avoidance: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
)
external_data_tools: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
)
model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
user_input_form: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
)
dataset_query_variable: str | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
prompt_type: str | None = None
chat_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
)
completion_prompt_config: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
)
dataset_configs: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
)
file_upload: JSONValue | None = Field(
default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
)
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class Site(ResponseModel):
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str | None = None
icon_type: str | IconType | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str | None = None
chat_color_theme: str | None = None
chat_color_theme_inverted: bool | None = None
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str | None = None
prompt_public: bool | None = None
app_base_url: str | None = None
show_workflow_steps: bool | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None

@computed_field(return_type=str | None)  # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)

@field_validator("icon_type", mode="before")
@classmethod
def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
if isinstance(value, IconType):
return value.value
return value

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class DeletedTool(ResponseModel):
type: str
tool_name: str
provider_id: str


class AppPartial(ResponseModel):
id: str
name: str
max_active_requests: int | None = None
description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
model_config_: ModelConfigPartial | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
tags: list[Tag] = Field(default_factory=list)
access_mode: str | None = None
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None

@computed_field(return_type=str | None)  # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class AppDetail(ResponseModel):
id: str
name: str
description: str | None = None
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon: str | None = None
icon_background: str | None = None
enable_site: bool
enable_api: bool
model_config_: ModelConfig | None = Field(
default=None,
validation_alias=AliasChoices("app_model_config", "model_config"),
alias="model_config",
)
workflow: WorkflowPartial | None = None
tracing: JSONValue | None = None
use_icon_as_answer_icon: bool | None = None
created_by: str | None = None
created_at: int | None = None
updated_by: str | None = None
updated_at: int | None = None
access_mode: str | None = None
tags: list[Tag] = Field(default_factory=list)

@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)


class AppDetailWithSite(AppDetail):
icon_type: str | None = None
api_base_url: str | None = None
max_active_requests: int | None = None
deleted_tools: list[DeletedTool] = Field(default_factory=list)
site: Site | None = None

@computed_field(return_type=str | None)  # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)


class AppPagination(ResponseModel):
page: int
limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
total: int
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))


class AppExportResponse(ResponseModel):
data: str


register_schema_models(
console_ns,
AppListQuery,
CreateAppPayload,
UpdateAppPayload,
CopyAppPayload,
AppExportQuery,
AppNamePayload,
AppIconPayload,
AppSiteStatusPayload,
AppApiStatusPayload,
AppTracePayload,
Tag,
WorkflowPartial,
ModelConfigPartial,
ModelConfig,
Site,
DeletedTool,
AppPartial,
AppDetail,
AppDetailWithSite,
AppPagination,
AppExportResponse,
)
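The endpoint hunks below all apply one migration pattern: each handler stops relying on flask-restx `marshal_with` plus a namespace model and instead validates the ORM object into one of the Pydantic `ResponseModel` subclasses above, then serializes it with `model_dump(mode="json")`. A minimal sketch of that pattern, using hypothetical `Widget`/`WidgetDetail` names (only `model_validate`, `model_dump`, and `ConfigDict` are real Pydantic v2 APIs here):

```
from pydantic import BaseModel, ConfigDict

class WidgetDetail(BaseModel):
    # from_attributes lets model_validate read plain objects / ORM rows
    model_config = ConfigDict(from_attributes=True)
    id: str
    name: str

def get(widget_orm_obj):
    # validate the row into the schema, then emit JSON-safe primitives
    detail = WidgetDetail.model_validate(widget_orm_obj, from_attributes=True)
    return detail.model_dump(mode="json"), 200
```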
@ -318,7 +480,7 @@ class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.response(200, "Success", app_pagination_model)
@console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
@setup_required
@login_required
@account_initialization_required
@ -334,7 +496,8 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200

if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
@ -378,18 +541,18 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids

return marshal(app_pagination, app_pagination_model), 200
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200

@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@ -399,8 +562,8 @@ class AppListApi(Resource):

app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)

return app, 201
app_detail = AppDetail.model_validate(app, from_attributes=True)
return app_detail.model_dump(mode="json"), 201


@console_ns.route("/apps/<uuid:app_id>")
@ -408,13 +571,12 @@ class AppApi(Resource):
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Success", app_detail_with_site_model)
@console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
@marshal_with(app_detail_with_site_model)
@get_app_model(mode=None)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@ -425,21 +587,21 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode

return app_model
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")

@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@get_app_model(mode=None)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@ -456,8 +618,8 @@ class AppApi(Resource):
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)

return app_model
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")

@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@ -483,14 +645,13 @@ class AppCopyApi(Resource):
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@get_app_model(mode=None)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
@ -516,7 +677,8 @@ class AppCopyApi(Resource):
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)

return app, 201
response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
return response_model.model_dump(mode="json"), 201


@console_ns.route("/apps/<uuid:app_id>/export")
@ -525,11 +687,7 @@ class AppExportApi(Resource):
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.response(
200,
"App exported successfully",
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
@console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@ -540,13 +698,14 @@ class AppExportApi(Resource):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

return {
"data": AppDslService.export_dsl(
payload = AppExportResponse(
data=AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
}
)
return payload.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/name")
@ -555,20 +714,19 @@ class AppNameApi(Resource):
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked")
@console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)

app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)

return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/icon")
@ -582,16 +740,15 @@ class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})

app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")

return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/site-enable")
@ -600,21 +757,20 @@ class AppSiteStatus(Resource):
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)

app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)

return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/api-enable")
@ -623,21 +779,20 @@ class AppApiStatus(Resource):
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_model)
@get_app_model(mode=None)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)

app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)

return app_model
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/trace")

@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
)

if args.keyword:
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)

@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

if args.keyword:
keyword_filter = f"%{args.keyword}%"
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(args.keyword)
keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
@ -469,11 +475,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
Message.query.ilike(keyword_filter, escape="\\"),
Message.answer.ilike(keyword_filter, escape="\\"),
Conversation.name.ilike(keyword_filter, escape="\\"),
Conversation.introduction.ilike(keyword_filter, escape="\\"),
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)

@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
"generation_detail": fields.Raw,
},
)

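`escape_like_pattern` is imported from `libs.helper` but its body is not shown in this diff; the point of the change is that user keywords are escaped before being embedded in `ILIKE` patterns, and every `ilike(...)` call now passes `escape="\\"` so the escape character is honored. A plausible sketch of such a helper (an assumption, not the actual Dify implementation):

```
def escape_like_pattern(keyword: str) -> str:
    """Escape LIKE/ILIKE metacharacters so user input matches literally.

    The backslash must be escaped first, then the % and _ wildcards.
    """
    return (
        keyword.replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )

# usage, mirroring the hunks above:
# Message.query.ilike(f"%{escape_like_pattern(kw)}%", escape="\\")
```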
@ -1,3 +1,5 @@
from typing import Any

import flask_login
from flask import make_response, request
from flask_restx import Resource
@ -96,14 +98,13 @@ class LoginApi(Resource):
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()

# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token  # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
invitation_data: dict[str, Any] | None = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)

try:
if invitation:
data = invitation.get("data", {})  # type: ignore
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
raise InvalidEmailError()

@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)

if keyword:
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword)
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
).ilike(f"%{escaped_keyword}%", escape="\\")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")

query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
keywords_condition,
)
)

@ -1,7 +1,7 @@
import logging
from typing import Any

from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)

@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
"""Validate and return hit-testing arguments from an incoming payload."""
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
return hit_testing_payload.model_dump(exclude_none=True)

@staticmethod
def perform_hit_testing(dataset, args):

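`parse_args` now takes the raw JSON payload and funnels it through a Pydantic model instead of a `reqparse` parser. `HitTestingPayload` itself is defined outside this hunk; a minimal sketch of what the change assumes it looks like, with field names inferred from the removed `reqparse` arguments:

```
from typing import Any
from pydantic import BaseModel

class HitTestingPayload(BaseModel):
    query: str | None = None
    attachment_ids: list[str] | None = None
    retrieval_model: dict[str, Any] | None = None
    external_retrieval_model: dict[str, Any] | None = None

# usage: drop unset keys so downstream checks see only provided arguments
args = HitTestingPayload.model_validate({"query": "hello"}).model_dump(exclude_none=True)
assert args == {"query": "hello"}
```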
@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource):
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE,
streaming=streaming,
)


@ -1,7 +1,7 @@
from typing import Literal

from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from werkzeug.exceptions import Forbidden

import services
@ -15,18 +15,21 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService

from . import console_ns

register_schema_models(console_ns, UploadConfig, FileResponse)

PREVIEW_WORDS_LIMIT = 3000


@ -35,26 +38,27 @@ class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(upload_config_fields)
@console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
def get(self):
return {
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
config = UploadConfig(
file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
)
return config.model_dump(mode="json"), 200

@setup_required
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
@ -90,7 +94,8 @@ class FileApi(Resource):
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)

return upload_file, 201
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201


@console_ns.route("/files/<uuid:file_id>/preview")

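`UploadConfig` replaces the hand-built dict plus `upload_config_fields` marshalling, so the response shape is declared once as a Pydantic model. A sketch of what such a model presumably declares, inferred from the keyword arguments above (the real definition lives in `fields/file_fields.py` and is not part of this hunk):

```
from pydantic import BaseModel

class UploadConfig(BaseModel):
    file_size_limit: int
    batch_count_limit: int
    file_upload_limit: int
    image_file_size_limit: int
    video_file_size_limit: int
    audio_file_size_limit: int
    workflow_file_upload_limit: int
    image_file_batch_limit: int
    single_chunk_attachment_limit: int
    attachment_image_file_size_limit: int
```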
@ -1,7 +1,7 @@
import urllib.parse

import httpx
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field

import services
@ -11,19 +11,22 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService

from . import console_ns

register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)


@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource):
# fall back to GET method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return {
"file_type": resp.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(resp.headers.get("Content-Length", 0)),
}
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")


class RemoteFileUploadPayload(BaseModel):
@ -50,7 +54,7 @@ console_ns.schema_model(
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url)
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload.model_dump(mode="json"), 201

@ -1,3 +1,5 @@
from __future__ import annotations

from datetime import datetime
from typing import Literal

@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
repeat_new_password: str

@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
def check_passwords_match(self) -> AccountPasswordPayload:
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self

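The new `from __future__ import annotations` makes every annotation lazily evaluated, so the forward reference to the class still being defined no longer needs its quoted string form. A minimal illustration of the same idea:

```
from __future__ import annotations

class Node:
    # without the future import this return annotation would need the string form "Node"
    def clone(self) -> Node:
        return Node()
```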
@ -4,18 +4,18 @@ from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden

import services
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model
from fields.file_fields import FileResponse

from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from ..common.schema import register_schema_models
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
@ -35,6 +35,8 @@ files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

register_schema_models(files_ns, FileResponse)


@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource):
415: "Unsupported file type",
}
)
@files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
@files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__])
def post(self):
"""Upload a file for plugin usage.

@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource):
"""
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

file: FileStorage | None = request.files.get("file")
file = request.files.get("file")
if file is None:
raise Forbidden("File is required.")

@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource):
user_id = args.user_id
user = get_user(tenant_id, user_id)

filename: str | None = file.filename
mimetype: str | None = file.mimetype
filename = file.filename
mimetype = file.mimetype

if not filename or not mimetype:
raise Forbidden("Invalid request.")
@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource):
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)

# Create a dictionary with all the necessary attributes
result = {
"id": tool_file.id,
"user_id": tool_file.user_id,
"tenant_id": tool_file.tenant_id,
"conversation_id": tool_file.conversation_id,
"file_key": tool_file.file_key,
"mimetype": tool_file.mimetype,
"original_url": tool_file.original_url,
"name": tool_file.name,
"size": tool_file.size,
"mime_type": mimetype,
"extension": extension,
"preview_url": preview_url,
}
result = FileResponse(
id=tool_file.id,
name=tool_file.name,
size=tool_file.size,
extension=extension,
mime_type=mimetype,
preview_url=preview_url,
source_url=tool_file.original_url,
original_url=tool_file.original_url,
user_id=tool_file.user_id,
tenant_id=tool_file.tenant_id,
conversation_id=tool_file.conversation_id,
file_key=tool_file.file_key,
)

return result, 201
return result.model_dump(mode="json"), 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:

@ -10,13 +10,16 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.file_fields import build_file_model
from fields.file_fields import FileResponse
from models import App, EndUser
from services.file_service import FileService

register_schema_models(service_api_ns, FileResponse)


@service_api_ns.route("/files/upload")
class FileApi(Resource):
@ -31,8 +34,8 @@ class FileApi(Resource):
415: "Unsupported file type",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))  # type: ignore
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
def post(self, app_model: App, end_user: EndUser):
"""Upload a file for use in conversations.

@ -64,4 +67,5 @@ class FileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

return upload_file, 201
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201

@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)

dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
args = self.parse_args(service_api_ns.payload)
self.hit_testing_args_check(args)

return self.perform_hit_testing(dataset, args)

@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource):
pipeline=pipeline,
user=current_user,
args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)


@ -1,9 +1,11 @@
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from typing import Literal

from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService


class ConversationListQuery(BaseModel):
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"

@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)


class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False

@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self


register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)


@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()

pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
raw_args = request.args.to_dict()
query = ConversationListQuery.model_validate(raw_args)

try:
with Session(db.engine) as session:
@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource):
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=query.last_id,
limit=query.limit,
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
pinned=query.pinned,
sort_by=query.sort_by,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource):

conversation_id = str(c_id)

parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
payload = ConversationRenamePayload.model_validate(web_ns.payload or {})

try:
conversation = ConversationService.rename(
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
app_model, conversation_id, end_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)

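The list endpoint now validates `request.args` with `ConversationListQuery` instead of a `reqparse` parser: Pydantic enforces the 1-100 range on `limit`, coerces the `pinned` query string ("true"/"false") to a bool, and restricts `sort_by` to the four literals, replacing the removed manual `args["pinned"] == "true"` branch. A small self-contained illustration of that coercion, reusing the model from the hunk above:

```
from typing import Literal
from pydantic import BaseModel, Field

class ConversationListQuery(BaseModel):
    last_id: str | None = None
    limit: int = Field(default=20, ge=1, le=100)
    pinned: bool | None = None
    sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"

# query-string values arrive as str; Pydantic coerces "true" -> True and "15" -> 15
q = ConversationListQuery.model_validate({"pinned": "true", "limit": "15"})
assert q.pinned is True and q.limit == 15 and q.sort_by == "-updated_at"
```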
@ -1,5 +1,4 @@
from flask import request
from flask_restx import marshal_with

import services
from controllers.common.errors import (
@ -9,12 +8,15 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from fields.file_fields import build_file_model
from fields.file_fields import FileResponse
from services.file_service import FileService

register_schema_models(web_ns, FileResponse)


@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@ -28,7 +30,7 @@ class FileApi(WebApiResource):
415: "Unsupported file type",
}
)
@marshal_with(build_file_model(web_ns))
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
def post(self, app_model, end_user):
"""Upload a file for use in web applications.

@ -81,4 +83,5 @@ class FileApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

return upload_file, 201
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201

@ -1,7 +1,6 @@
import urllib.parse

import httpx
from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl

import services
@ -14,7 +13,7 @@ from controllers.common.errors import (
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from services.file_service import FileService

from ..common.schema import register_schema_models
@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")


register_schema_models(web_ns, RemoteFileUploadPayload)
register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl)


@web_ns.route("/remote-files/<path:url>")
@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
@marshal_with(build_remote_file_info_model(web_ns))
@web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
def get(self, app_model, end_user, url):
"""Get information about a remote file.

@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource):
# fall back to GET method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return {
"file_type": resp.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(resp.headers.get("Content-Length", -1)),
}
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", -1)),
)
return info.model_dump(mode="json")


@web_ns.route("/remote-files/upload")
@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
@marshal_with(build_file_with_signed_url_model(web_ns))
@web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
def post(self, app_model, end_user):
"""Upload a file from a remote URL.

@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError

return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201
payload1 = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload1.model_dump(mode="json"), 201

@ -1,18 +1,30 @@
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService


class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)


class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty


register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)


@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
@web_ns.doc("Get Saved Messages")
@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()

parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)

pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()

parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})

try:
SavedMessageService.save(app_model, end_user, args["message_id"])
SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

380
api/core/agent/agent_app_runner.py
Normal file
@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager

def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)

# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)

return tool_invoke_response, message_files, tool_invoke_meta

return tool_invoke_hook

def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity

app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"

# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()

assert app_config.agent

# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)

# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""

# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)

# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []

# organize prompt messages
prompt_messages = self._organize_prompt_messages()

# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)

# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break

if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True

yield output

elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False

elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue

# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})

self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)

elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue

thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)

elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue

# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")

# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}

self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None

self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)

elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@ -5,7 +5,7 @@ from typing import Union, cast

from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner):
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.model_features = features
        self.query: str | None = ""
        self._current_thoughts: list[PromptMessage] = []

    def build_execution_context(self) -> ExecutionContext:
        """Build execution context."""
        return ExecutionContext(
            user_id=self.user_id,
            app_id=self.app_config.app_id,
            conversation_id=self.conversation.id,
            message_id=self.message.id,
            tenant_id=self.tenant_id,
        )

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:

@ -1,437 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message

logger = logging.getLogger(__name__)


class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    _agent_scratchpad: list[AgentScratchpadUnit]
    _instruction: str
    _query: str
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(
        self,
        message: Message,
        query: str,
        inputs: Mapping[str, str],
    ) -> Generator:
        """
        Run Cot agent application
        """

        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
        self._init_react_state(query)

        trace_manager = app_generate_entity.trace_manager

        # check model mode
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config
        assert app_config.agent

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template or ""
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
        self._prompt_messages_tools = prompt_messages_tools

        function_call_state = True
        llm_usage: dict[str, LLMUsage | None] = {"usage": None}
        final_answer = ""
        prompt_messages: list = []  # Initialize prompt_messages
        agent_thought_id = ""  # Initialize agent_thought_id

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.total_tokens += usage.total_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            usage_dict: dict[str, LLMUsage | None] = {}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
                    action = chunk
                    # detect action
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += chunk
                    assert scratchpad.thought is not None
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )

            assert scratchpad.thought is not None
            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # Check if max iteration is reached and model still wants to call tools
            if iteration_step == max_iteration_steps and scratchpad.action:
                if scratchpad.action.action_name.lower() != "final answer":
                    raise AgentMaxIterationError(app_config.agent.max_iteration)

            # get llm usage
            if "usage" in usage_dict:
                if usage_dict["usage"] is not None:
                    increase_usage(llm_usage, usage_dict["usage"])
            else:
                usage_dict["usage"] = LLMUsage.empty_usage()

            self.save_agent_thought(
                agent_thought_id=agent_thought_id,
                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                tool_invoke_meta={},
                thought=scratchpad.thought or "",
                observation="",
                answer=scratchpad.agent_response or "",
                messages_ids=[],
                llm_usage=usage_dict["usage"],
            )

            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except TypeError:
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    function_call_state = True
                    # action is tool call, invoke tool
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids,
                        trace_manager=trace_manager,
                    )
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    self.save_agent_thought(
                        agent_thought_id=agent_thought_id,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought or "",
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                    )

            # update prompt tool message
            for prompt_tool in self._prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought_id=agent_thought_id,
            tool_name="",
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[],
        )
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: Mapping[str, Tool],
        message_file_ids: list[str],
        trace_manager: TraceQueueManager | None = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
        :param tool_instances: tool instances
        :param message_file_ids: message file ids
        :param trace_manager: trace manager
        :return: observation, meta
        """
        # action is tool call, invoke tool
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            answer = f"there is not a tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback,
            trace_manager=trace_manager,
        )

        # publish files
        for message_file_id in message_files:
            # publish message file
            self.queue_manager.publish(
                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
            message_file_ids.append(message_file_id)

        return tool_invoke_response, tool_invoke_meta

    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
        """
        fill in inputs from external data tools
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
            except Exception:
                continue

        return instruction

    def _init_react_state(self, query):
        """
        init agent scratchpad
        """
        self._query = query
        self._agent_scratchpad = []
        self._historic_prompt_messages = self._organize_historic_prompt_messages()

    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ""
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
            else:
                message += f"Thought: {scratchpad.thought}\n\n"
                if scratchpad.action_str:
                    message += f"Action: {scratchpad.action_str}\n\n"
                if scratchpad.observation:
                    message += f"Observation: {scratchpad.observation}\n\n"

        return message

    def _organize_historic_prompt_messages(
        self, current_session_messages: list[PromptMessage] | None = None
    ) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpads: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit | None = None

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                if not current_scratchpad:
                    assert isinstance(message.content, str)
                    current_scratchpad = AgentScratchpadUnit(
                        agent_response=message.content,
                        thought=message.content or "I am thinking about how to help you",
                        action_str="",
                        action=None,
                        observation=None,
                    )
                    scratchpads.append(current_scratchpad)
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments),
                        )
                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                    except Exception:
                        logger.exception("Failed to parse tool call from assistant message")
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    assert isinstance(message.content, str)
                    current_scratchpad.observation = message.content
                else:
                    raise NotImplementedError("expected str type")
            elif isinstance(message, UserPromptMessage):
                if scratchpads:
                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                    scratchpads = []
                    current_scratchpad = None

                result.append(message)

        if scratchpads:
            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

        historic_prompts = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=result,
            memory=self.memory,
        ).get_prompt()
        return historic_prompts

@ -1,118 +0,0 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder


class CotChatAgentRunner(CotAgentRunner):
    def _organize_system_prompt(self) -> SystemPromptMessage:
        """
        Organize system prompt
        """
        assert self.app_config.agent
        assert self.app_config.agent.prompt

        prompt_entity = self.app_config.agent.prompt
        if not prompt_entity:
            raise ValueError("Agent prompt configuration is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return SystemPromptMessage(content=system_prompt)

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:
            assistant_messages = []
        else:
            assistant_message = AssistantPromptMessage(content="")
            assistant_message.content = ""  # FIXME: tell mypy that assistant_message.content is str
            for unit in agent_scratchpad:
                if unit.is_final():
                    assert isinstance(assistant_message.content, str)
                    assistant_message.content += f"Final Answer: {unit.agent_response}"
                else:
                    assert isinstance(assistant_message.content, str)
                    assistant_message.content += f"Thought: {unit.thought}\n\n"
                    if unit.action_str:
                        assistant_message.content += f"Action: {unit.action_str}\n\n"
                    if unit.observation:
                        assistant_message.content += f"Observation: {unit.observation}\n\n"

            assistant_messages = [assistant_message]

        # query messages
        query_messages = self._organize_user_query(self._query, [])

        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages(
                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
            )
            messages = [
                system_message,
                *historic_messages,
                *query_messages,
                *assistant_messages,
                UserPromptMessage(content="continue"),
            ]
        else:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
            messages = [system_message, *historic_messages, *query_messages]

        # join all messages
        return messages

@ -1,87 +0,0 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder


class CotCompletionAgentRunner(CotAgentRunner):
    def _organize_instruction_prompt(self) -> str:
        """
        Organize instruction prompt
        """
        if self.app_config.agent is None:
            raise ValueError("Agent configuration is not set")
        prompt_entity = self.app_config.agent.prompt
        if prompt_entity is None:
            raise ValueError("prompt entity is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return system_prompt

    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""

        for message in historic_prompt_messages:
            if isinstance(message, UserPromptMessage):
                historic_prompt += f"Question: {message.content}\n\n"
            elif isinstance(message, AssistantPromptMessage):
                if isinstance(message.content, str):
                    historic_prompt += message.content + "\n\n"
                elif isinstance(message.content, list):
                    for content in message.content:
                        if not isinstance(content, TextPromptMessageContent):
                            continue
                        historic_prompt += content.data

        return historic_prompt

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        # organize system prompt
        system_prompt = self._organize_instruction_prompt()

        # organize historic prompt messages
        historic_prompt = self._organize_historic_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        assistant_prompt = ""
        for unit in agent_scratchpad or []:
            if unit.is_final():
                assistant_prompt += f"Final Answer: {unit.agent_response}"
            else:
                assistant_prompt += f"Thought: {unit.thought}\n\n"
                if unit.action_str:
                    assistant_prompt += f"Action: {unit.action_str}\n\n"
                if unit.observation:
                    assistant_prompt += f"Observation: {unit.observation}\n\n"

        # query messages
        query_prompt = f"Question: {self._query}"

        # join all messages
        prompt = (
            system_prompt.replace("{{historic_messages}}", historic_prompt)
            .replace("{{agent_scratchpad}}", assistant_prompt)
            .replace("{{query}}", query_prompt)
        )

        return [UserPromptMessage(content=prompt)]

@ -1,3 +1,5 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union

@ -92,3 +94,96 @@ class AgentInvokeMessage(ToolInvokeMessage):
    """

    pass


class ExecutionContext(BaseModel):
    """Execution context containing trace and audit information.

    This context carries all the IDs and metadata that are not part of
    the core business logic but are needed for tracing, auditing, and
    correlation purposes.
    """

    user_id: str | None = None
    app_id: str | None = None
    conversation_id: str | None = None
    message_id: str | None = None
    tenant_id: str | None = None

    @classmethod
    def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
        """Create a minimal context with only essential fields."""
        return cls(user_id=user_id)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for passing to legacy code."""
        return {
            "user_id": self.user_id,
            "app_id": self.app_id,
            "conversation_id": self.conversation_id,
            "message_id": self.message_id,
            "tenant_id": self.tenant_id,
        }

    def with_updates(self, **kwargs) -> "ExecutionContext":
        """Create a new context with updated fields."""
        data = self.to_dict()
        data.update(kwargs)

        return ExecutionContext(
            user_id=data.get("user_id"),
            app_id=data.get("app_id"),
            conversation_id=data.get("conversation_id"),
            message_id=data.get("message_id"),
            tenant_id=data.get("tenant_id"),
        )


class AgentLog(BaseModel):
    """
    Agent Log.
    """

    class LogType(StrEnum):
        """Type of agent log entry."""

        ROUND = "round"  # A complete iteration round
        THOUGHT = "thought"  # LLM thinking/reasoning
        TOOL_CALL = "tool_call"  # Tool invocation

    class LogMetadata(StrEnum):
        STARTED_AT = "started_at"
        FINISHED_AT = "finished_at"
        ELAPSED_TIME = "elapsed_time"
        TOTAL_PRICE = "total_price"
        TOTAL_TOKENS = "total_tokens"
        PROVIDER = "provider"
        CURRENCY = "currency"
        LLM_USAGE = "llm_usage"
        ICON = "icon"
        ICON_DARK = "icon_dark"

    class LogStatus(StrEnum):
        START = "start"
        ERROR = "error"
        SUCCESS = "success"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
    label: str = Field(..., description="The label of the log")
    log_type: LogType = Field(..., description="The type of the log")
    parent_id: str | None = Field(default=None, description="Leave empty for root log")
    error: str | None = Field(default=None, description="The error message")
    status: LogStatus = Field(..., description="The status of the log")
    data: Mapping[str, Any] = Field(..., description="Detailed log data")
    metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")


class AgentResult(BaseModel):
    """
    Agent execution result.
    """

    text: str = Field(default="", description="The generated text")
    files: list[Any] = Field(default_factory=list, description="Files produced during execution")
    usage: Any | None = Field(default=None, description="LLM usage statistics")
    finish_reason: str | None = Field(default=None, description="Reason for completion")

@ -1,470 +0,0 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        assert app_config.agent

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        llm_usage: dict[str, LLMUsage | None] = {"usage": None}
        final_answer = ""
        prompt_messages: list = []  # Initialize prompt_messages

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.total_tokens += usage.total_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids: list[str] = []
            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            if isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except TypeError:
                            # fall back to ASCII-escaped serialization
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += str(chunk.delta.message.content)

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except TypeError:
                        # fall back to ASCII-escaped serialization
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += str(result.message.content)

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )

            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought_id=agent_thought_id,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # Check if max iteration is reached and model still wants to call tools
            if iteration_step == max_iteration_steps and tool_calls:
                raise AgentMaxIterationError(app_config.agent.max_iteration)

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                        app_id=self.application_generate_entity.app_config.app_id,
                        message_id=self.message.id,
                        conversation_id=self.conversation.id,
                    )
                    # publish files
                    for message_file_id in message_files:
                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=str(tool_response["tool_response"]),
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought_id=agent_thought_id,
                    tool_name="",
                    tool_input="",
                    thought="",
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer="",
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, GPT supports both function calling and vision in the first iteration,
        so we need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages

55
api/core/agent/patterns/README.md
Normal file
@ -0,0 +1,55 @@
# Agent Patterns

A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.

## Overview

The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.

## Key Features

- **Dual strategies**
  - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
  - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
  - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
  - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
  - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
  - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
  - Tools convert to `PromptMessageTool` objects before invocation.
  - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
  - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
  - Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely (see the sketch after this list).
- **ReAct prompt shaping**
  - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
  - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
  - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
|
||||
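The placeholder convention is easiest to see in isolation. The sketch below is illustrative only: it re-implements the matching logic with plain dicts standing in for the `File` model, and the IDs are hypothetical.

```python
import re

# Hypothetical stand-ins for File objects, keyed by ID.
FILES = {"f-1": {"id": "f-1", "name": "report.pdf"}}

SINGLE = re.compile(r"^\[File:\s*([^\]]+)\]$")
MULTI = re.compile(r"^\[Files:\s*([^\]]+)\]$")

def resolve(value: str) -> object:
    """Resolve [File: id] / [Files: id1, id2] placeholders; unknown IDs pass through."""
    if m := SINGLE.match(value.strip()):
        return FILES.get(m.group(1).strip(), value)
    if m := MULTI.match(value.strip()):
        ids = [fid.strip() for fid in m.group(1).split(",")]
        return [FILES[fid] for fid in ids if fid in FILES] or value
    return value

print(resolve("[File: f-1]"))        # {'id': 'f-1', 'name': 'report.pdf'}
print(resolve("[Files: f-1, f-9]"))  # [{'id': 'f-1', 'name': 'report.pdf'}]
print(resolve("plain text"))         # 'plain text'
```

As in the real implementation, a placeholder that matches no known file is returned unchanged rather than raising.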
## Architecture

```
agent/patterns/
├── base.py              # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py     # Native function-calling loop with tool execution
├── react.py             # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py  # Strategy selection by model features or explicit override
```

## Usage

- For auto-selection:
  - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
- For explicit behavior:
  - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`, as sketched below.
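A minimal sketch of the auto-selection flow, assuming `model_features`, `model_instance`, `context`, `tools`, `files`, and `prompt_messages` are already prepared by the caller (the exact parameters are defined in `strategy_factory.py`):

```python
strategy = StrategyFactory.create_strategy(
    model_features=model_features,
    model_instance=model_instance,
    context=context,
    tools=tools,
    files=files,
    max_iterations=10,
)

# Drain the generator; its return value is the AgentResult.
gen = strategy.run(prompt_messages=prompt_messages, model_parameters={"temperature": 0.7})
while True:
    try:
        event = next(gen)  # AgentLog or LLMResultChunk, in stream order
    except StopIteration as stop:
        result = stop.value  # AgentResult: text, files, usage, finish_reason
        break
```

The `while`/`StopIteration` pattern is how Python surfaces a generator's return value; a plain `for` loop would discard the `AgentResult`.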
## Integration Points

- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
19
api/core/agent/patterns/__init__.py
Normal file
@ -0,0 +1,19 @@
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
474
api/core/agent/patterns/base.py
Normal file
@ -0,0 +1,474 @@
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine().generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
299
api/core/agent/patterns/function_call.py
Normal file
@ -0,0 +1,299 @@
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if isinstance(chunks, Generator):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
418
api/core/agent/patterns/react.py
Normal file
@ -0,0 +1,418 @@
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
107
api/core/agent/patterns/strategy_factory.py
Normal file
@ -0,0 +1,107 @@
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
    QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
@ -40,6 +42,7 @@ from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater

logger = logging.getLogger(__name__)

@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        )

        workflow_entry.graph_engine.layer(persistence_layer)
        conversation_variable_layer = ConversationVariablePersistenceLayer(
            ConversationVariableUpdater(session_factory.get_session_maker())
        )
        workflow_entry.graph_engine.layer(conversation_variable_layer)
        for layer in self._graph_engine_layers:
            workflow_entry.graph_engine.layer(layer)
@ -4,6 +4,7 @@ import re
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Union

@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import (
    InvokeFrom,
)
from core.app.entities.queue_entities import (
    ChunkType,
    MessageQueueMessage,
    QueueAdvancedChatMessageEndEvent,
    QueueAgentLogEvent,
@ -70,13 +72,122 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow

logger = logging.getLogger(__name__)


@dataclass
class StreamEventBuffer:
    """
    Buffer for recording stream events in order to reconstruct the generation sequence.
    Records the exact order of text chunks, thoughts, and tool calls as they stream.
    """

    # Accumulated reasoning content (each thought block is a separate element)
    reasoning_content: list[str] = field(default_factory=list)
    # Current reasoning buffer (accumulates until we see a different event type)
    _current_reasoning: str = ""
    # Tool calls with their details
    tool_calls: list[dict] = field(default_factory=list)
    # Tool call ID to index mapping for updating results
    _tool_call_id_map: dict[str, int] = field(default_factory=dict)
    # Sequence of events in stream order
    sequence: list[dict] = field(default_factory=list)
    # Current position in answer text
    _content_position: int = 0
    # Track last event type to detect transitions
    _last_event_type: str | None = None

    def _flush_current_reasoning(self) -> None:
        """Flush accumulated reasoning to the list and add to sequence."""
        if self._current_reasoning.strip():
            self.reasoning_content.append(self._current_reasoning.strip())
            self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
        self._current_reasoning = ""

    def record_text_chunk(self, text: str) -> None:
        """Record a text chunk event."""
        if not text:
            return

        # Flush any pending reasoning first
        if self._last_event_type == "thought":
            self._flush_current_reasoning()

        text_len = len(text)
        start_pos = self._content_position

        # If last event was also content, extend it; otherwise create new
        if self.sequence and self.sequence[-1].get("type") == "content":
            self.sequence[-1]["end"] = start_pos + text_len
        else:
            self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})

        self._content_position += text_len
        self._last_event_type = "content"

    def record_thought_chunk(self, text: str) -> None:
        """Record a thought/reasoning chunk event."""
        if not text:
            return

        # Accumulate thought content
        self._current_reasoning += text
        self._last_event_type = "thought"

    def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
        """Record a tool call event."""
        if not tool_call_id:
            return

        # Flush any pending reasoning first
        if self._last_event_type == "thought":
            self._flush_current_reasoning()

        # Check if this tool call already exists (we might get multiple chunks)
        if tool_call_id in self._tool_call_id_map:
            idx = self._tool_call_id_map[tool_call_id]
            # Update arguments if provided
            if tool_arguments:
                self.tool_calls[idx]["arguments"] = tool_arguments
        else:
            # New tool call
            tool_call = {
                "id": tool_call_id or "",
                "name": tool_name or "",
                "arguments": tool_arguments or "",
                "result": "",
                "elapsed_time": None,
            }
            self.tool_calls.append(tool_call)
            idx = len(self.tool_calls) - 1
            self._tool_call_id_map[tool_call_id] = idx
            self.sequence.append({"type": "tool_call", "index": idx})

        self._last_event_type = "tool_call"

    def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
        """Record a tool result event (update existing tool call)."""
        if not tool_call_id:
            return
        if tool_call_id in self._tool_call_id_map:
            idx = self._tool_call_id_map[tool_call_id]
            self.tool_calls[idx]["result"] = result
            self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time

    def finalize(self) -> None:
        """Finalize the buffer, flushing any pending data."""
        if self._last_event_type == "thought":
            self._flush_current_reasoning()

    def has_data(self) -> bool:
        """Check if there's any meaningful data recorded."""
        return bool(self.reasoning_content or self.tool_calls or self.sequence)

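Before the pipeline class below, a quick illustration of the buffer's contract may help. This is a minimal sketch, not part of the commit; the chunk values are invented:

```python
# Feed a mixed stream into the buffer, then inspect the reconstruction.
buffer = StreamEventBuffer()
buffer.record_thought_chunk("I should call the weather tool.")
buffer.record_tool_call(tool_call_id="call_1", tool_name="weather", tool_arguments='{"city": "Paris"}')
buffer.record_tool_result(tool_call_id="call_1", result="18°C, cloudy", tool_elapsed_time=0.4)
buffer.record_text_chunk("It is ")
buffer.record_text_chunk("18°C in Paris.")
buffer.finalize()

assert buffer.has_data()
# Adjacent text chunks are merged into one content segment, so the recorded
# sequence is: [{"type": "reasoning", "index": 0},
#               {"type": "tool_call", "index": 0},
#               {"type": "content", "start": 0, "end": 20}]
```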
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
    """
    AdvancedChatAppGenerateTaskPipeline is a class that generates stream output and manages state for an application.
@ -144,6 +255,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        self._workflow_run_id: str = ""
        self._draft_var_saver_factory = draft_var_saver_factory
        self._graph_runtime_state: GraphRuntimeState | None = None
        # Stream event buffer for recording generation sequence
        self._stream_buffer = StreamEventBuffer()
        self._seed_graph_runtime_state_from_queue_manager()

    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -358,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if node_finish_resp:
            yield node_finish_resp

        # For ANSWER nodes, check if we need to send a message_replace event
        # Only send if the final output differs from the accumulated task_state.answer
        # This happens when variables were updated by variable_assigner during workflow execution
        if event.node_type == NodeType.ANSWER and event.outputs:
            final_answer = event.outputs.get("answer")
            if final_answer is not None and final_answer != self._task_state.answer:
                logger.info(
                    "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
                    final_answer,
                    self._task_state.answer,
                )
                # Update the task state answer
                self._task_state.answer = str(final_answer)
                # Send message_replace event to update the UI
                yield self._message_cycle_manager.message_replace_to_stream_response(
                    answer=str(final_answer),
                    reason="variable_update",
                )

    def _handle_node_failed_events(
        self,
        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
@ -383,7 +515,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
        **kwargs,
    ) -> Generator[StreamResponse, None, None]:
        """Handle text chunk events."""
        """Handle text chunk events and record to stream buffer for sequence reconstruction."""
        delta_text = event.text
        if delta_text is None:
            return
@ -405,9 +537,52 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if tts_publisher and queue_message:
            tts_publisher.publish(queue_message)

        self._task_state.answer += delta_text
        tool_call = event.tool_call
        tool_result = event.tool_result
        tool_payload = tool_call or tool_result
        tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
        tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
        tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
        tool_files = tool_result.files if tool_result else []
        tool_elapsed_time = tool_result.elapsed_time if tool_result else None
        tool_icon = tool_payload.icon if tool_payload else None
        tool_icon_dark = tool_payload.icon_dark if tool_payload else None
        # Record stream event based on chunk type
        chunk_type = event.chunk_type or ChunkType.TEXT
        match chunk_type:
            case ChunkType.TEXT:
                self._stream_buffer.record_text_chunk(delta_text)
                self._task_state.answer += delta_text
            case ChunkType.THOUGHT:
                # Reasoning should not be part of final answer text
                self._stream_buffer.record_thought_chunk(delta_text)
            case ChunkType.TOOL_CALL:
                self._stream_buffer.record_tool_call(
                    tool_call_id=tool_call_id,
                    tool_name=tool_name,
                    tool_arguments=tool_arguments,
                )
            case ChunkType.TOOL_RESULT:
                self._stream_buffer.record_tool_result(
                    tool_call_id=tool_call_id,
                    result=delta_text,
                    tool_elapsed_time=tool_elapsed_time,
                )
                self._task_state.answer += delta_text
            case _:
                pass
        yield self._message_cycle_manager.message_to_stream_response(
            answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
            answer=delta_text,
            message_id=self._message_id,
            from_variable_selector=event.from_variable_selector,
            chunk_type=event.chunk_type.value if event.chunk_type else None,
            tool_call_id=tool_call_id or None,
            tool_name=tool_name or None,
            tool_arguments=tool_arguments or None,
            tool_files=tool_files,
            tool_elapsed_time=tool_elapsed_time,
            tool_icon=tool_icon,
            tool_icon_dark=tool_icon_dark,
        )

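The dispatch above is what keeps agent reasoning out of the visible answer: only TEXT and TOOL_RESULT chunks are appended to `task_state.answer`. A tiny illustration with invented chunks:

```python
# Illustration only: THOUGHT chunks are buffered for the generation-detail
# record but never reach the accumulated answer text.
chunks = [("thought", "check the docs"), ("text", "Hello"), ("tool_result", " world")]
answer = "".join(text for kind, text in chunks if kind in ("text", "tool_result"))
assert answer == "Hello world"
```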
    def _handle_iteration_start_event(
@ -775,6 +950,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

        # If there are assistant files, remove markdown image links from answer
        answer_text = self._task_state.answer
        answer_text = self._strip_think_blocks(answer_text)
        if self._recorded_files:
            # Remove markdown image links since we're storing files separately
            answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
@ -826,6 +1002,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        ]
        session.add_all(message_files)

        # Save generation detail (reasoning/tool calls/sequence) from stream buffer
        self._save_generation_detail(session=session, message=message)

    @staticmethod
    def _strip_think_blocks(text: str) -> str:
        """Remove <think>...</think> blocks (including their content) from text."""
        if not text or "<think" not in text.lower():
            return text

        clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
        return clean_text

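A worked example of the helper above, with an invented input:

```python
# The think block is dropped entirely and leftover blank runs are collapsed.
text = "<think>plan the reply</think>Hello!\n\n\nBye."
# _strip_think_blocks(text) -> "Hello!\n\nBye."
```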
    def _save_generation_detail(self, *, session: Session, message: Message) -> None:
        """
        Save LLM generation detail for Chatflow using stream event buffer.
        The buffer records the exact order of events as they streamed,
        allowing accurate reconstruction of the generation sequence.
        """
        # Finalize the stream buffer to flush any pending data
        self._stream_buffer.finalize()

        # Only save if there's meaningful data
        if not self._stream_buffer.has_data():
            return

        reasoning_content = self._stream_buffer.reasoning_content
        tool_calls = self._stream_buffer.tool_calls
        sequence = self._stream_buffer.sequence

        # Check if generation detail already exists for this message
        existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()

        if existing:
            existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
            existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
            existing.sequence = json.dumps(sequence) if sequence else None
        else:
            generation_detail = LLMGenerationDetail(
                tenant_id=self._application_generate_entity.app_config.tenant_id,
                app_id=self._application_generate_entity.app_config.app_id,
                message_id=message.id,
                reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
                tool_calls=json.dumps(tool_calls) if tool_calls else None,
                sequence=json.dumps(sequence) if sequence else None,
            )
            session.add(generation_detail)

    def _seed_graph_runtime_state_from_queue_manager(self) -> None:
        """Bootstrap the cached runtime state from the queue manager when present."""
        candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

@ -3,10 +3,8 @@ from typing import cast

from sqlalchemy import select

from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner):
            raise ValueError("Message not found")
        db.session.close()

        runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
        # start agent runner
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            # check LLM mode
            if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
                runner_cls = CotChatAgentRunner
            elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
                runner_cls = CotCompletionAgentRunner
            else:
                raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
        elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            runner_cls = FunctionCallAgentRunner
        else:
            raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")

        runner = runner_cls(
        runner = AgentAppRunner(
            tenant_id=app_config.tenant_id,
            application_generate_entity=application_generate_entity,
            conversation=conversation_result,

@ -671,7 +671,7 @@ class WorkflowResponseConverter:
            task_id=task_id,
            data=AgentLogStreamResponse.Data(
                node_execution_id=event.node_execution_id,
                id=event.id,
                message_id=event.id,
                parent_id=event.parent_id,
                label=event.label,
                error=event.error,

@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
        )
        documents: list[Document] = []
        if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
        if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
            from services.dataset_service import DocumentService

            for datasource_info in datasource_info_list:
@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
            document_id = args.get("original_document_id") or None
            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
            if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
                document_id = document_id or documents[i].id
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document_id,

@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
    AppQueueEvent,
    ChunkType,
    MessageQueueMessage,
    QueueAgentLogEvent,
    QueueErrorEvent,
@ -483,11 +484,33 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if delta_text is None:
            return

        tool_call = event.tool_call
        tool_result = event.tool_result
        tool_payload = tool_call or tool_result
        tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
        tool_name = tool_payload.name if tool_payload and tool_payload.name else None
        tool_arguments = tool_call.arguments if tool_call else None
        tool_elapsed_time = tool_result.elapsed_time if tool_result else None
        tool_files = tool_result.files if tool_result else []
        tool_icon = tool_payload.icon if tool_payload else None
        tool_icon_dark = tool_payload.icon_dark if tool_payload else None

        # only publish tts message at text chunk streaming
        if tts_publisher and queue_message:
            tts_publisher.publish(queue_message)

        yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
        yield self._text_chunk_to_stream_response(
            text=delta_text,
            from_variable_selector=event.from_variable_selector,
            chunk_type=event.chunk_type,
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            tool_arguments=tool_arguments,
            tool_files=tool_files,
            tool_elapsed_time=tool_elapsed_time,
            tool_icon=tool_icon,
            tool_icon_dark=tool_icon_dark,
        )

    def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
        """Handle agent log events."""
@ -650,16 +673,61 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        session.add(workflow_app_log)

    def _text_chunk_to_stream_response(
        self, text: str, from_variable_selector: list[str] | None = None
        self,
        text: str,
        from_variable_selector: list[str] | None = None,
        chunk_type: ChunkType | None = None,
        tool_call_id: str | None = None,
        tool_name: str | None = None,
        tool_arguments: str | None = None,
        tool_files: list[str] | None = None,
        tool_error: str | None = None,
        tool_elapsed_time: float | None = None,
        tool_icon: str | dict | None = None,
        tool_icon_dark: str | dict | None = None,
    ) -> TextChunkStreamResponse:
        """
        Handle completed event.
        :param text: text
        :return:
        """
        from core.app.entities.task_entities import ChunkType as ResponseChunkType

        response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT

        data = TextChunkStreamResponse.Data(
            text=text,
            from_variable_selector=from_variable_selector,
            chunk_type=response_chunk_type,
        )

        if response_chunk_type == ResponseChunkType.TOOL_CALL:
            data = data.model_copy(
                update={
                    "tool_call_id": tool_call_id,
                    "tool_name": tool_name,
                    "tool_arguments": tool_arguments,
                    "tool_icon": tool_icon,
                    "tool_icon_dark": tool_icon_dark,
                }
            )
        elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
            data = data.model_copy(
                update={
                    "tool_call_id": tool_call_id,
                    "tool_name": tool_name,
                    "tool_arguments": tool_arguments,
                    "tool_files": tool_files,
                    "tool_error": tool_error,
                    "tool_elapsed_time": tool_elapsed_time,
                    "tool_icon": tool_icon,
                    "tool_icon_dark": tool_icon_dark,
                }
            )

        response = TextChunkStreamResponse(
            task_id=self._application_generate_entity.task_id,
            data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
            data=data,
        )

        return response

@ -455,12 +455,20 @@ class WorkflowBasedAppRunner:
            )
        )
    elif isinstance(event, NodeRunStreamChunkEvent):
        from core.app.entities.queue_entities import ChunkType as QueueChunkType

        if event.is_final and not event.chunk:
            return

        self._publish_event(
            QueueTextChunkEvent(
                text=event.chunk,
                from_variable_selector=list(event.selector),
                in_iteration_id=event.in_iteration_id,
                in_loop_id=event.in_loop_id,
                chunk_type=QueueChunkType(event.chunk_type.value),
                tool_call=event.tool_call,
                tool_result=event.tool_result,
            )
        )
    elif isinstance(event, NodeRunRetrieverResourceEvent):

@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
    # DEBUGGER indicates that this invocation is from
    # the workflow (or chatflow) edit page.
    DEBUGGER = "debugger"
    PUBLISHED = "published"
    # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
    PUBLISHED_PIPELINE = "published"

    # VALIDATION indicates that this invocation is from validation.
    VALIDATION = "validation"

70
api/core/app/entities/llm_generation_entities.py
Normal file
@ -0,0 +1,70 @@
"""
LLM Generation Detail entities.

Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""

from typing import Literal

from pydantic import BaseModel, Field


class ContentSegment(BaseModel):
    """Represents a content segment in the generation sequence."""

    type: Literal["content"] = "content"
    start: int = Field(..., description="Start position in the text")
    end: int = Field(..., description="End position in the text")


class ReasoningSegment(BaseModel):
    """Represents a reasoning segment in the generation sequence."""

    type: Literal["reasoning"] = "reasoning"
    index: int = Field(..., description="Index into reasoning_content array")


class ToolCallSegment(BaseModel):
    """Represents a tool call segment in the generation sequence."""

    type: Literal["tool_call"] = "tool_call"
    index: int = Field(..., description="Index into tool_calls array")


SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment


class ToolCallDetail(BaseModel):
    """Represents a tool call with its arguments and result."""

    id: str = Field(default="", description="Unique identifier for the tool call")
    name: str = Field(..., description="Name of the tool")
    arguments: str = Field(default="", description="JSON string of tool arguments")
    result: str = Field(default="", description="Result from the tool execution")
    elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")


class LLMGenerationDetailData(BaseModel):
    """
    Domain model for LLM generation detail.

    Contains the structured data for reasoning content, tool calls,
    and their display sequence.
    """

    reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
    tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
    sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")

    def is_empty(self) -> bool:
        """Check if there's any meaningful generation detail."""
        return not self.reasoning_content and not self.tool_calls

    def to_response_dict(self) -> dict:
        """Convert to dictionary for API response."""
        return {
            "reasoning_content": self.reasoning_content,
            "tool_calls": [tc.model_dump() for tc in self.tool_calls],
            "sequence": [seg.model_dump() for seg in self.sequence],
        }
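A minimal sketch (not from the commit) of rehydrating the JSON columns written by the pipelines above into this domain model; the raw values are invented:

```python
import json

raw_reasoning = json.dumps(["I should call the weather tool."])
raw_tool_calls = json.dumps([{"name": "weather", "arguments": "{}", "result": "18°C"}])
raw_sequence = json.dumps([{"type": "reasoning", "index": 0}, {"type": "tool_call", "index": 0}])

detail = LLMGenerationDetailData(
    reasoning_content=json.loads(raw_reasoning),
    tool_calls=json.loads(raw_tool_calls),
    sequence=json.loads(raw_sequence),
)
assert not detail.is_empty()
payload = detail.to_response_dict()  # plain dicts, ready for the API layer
```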
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType

@ -177,6 +177,17 @@ class QueueLoopCompletedEvent(AppQueueEvent):
    error: str | None = None


class ChunkType(StrEnum):
    """Stream chunk type for LLM-related events."""

    TEXT = "text"  # Normal text streaming
    TOOL_CALL = "tool_call"  # Tool call arguments streaming
    TOOL_RESULT = "tool_result"  # Tool execution result
    THOUGHT = "thought"  # Agent thinking process (ReAct)
    THOUGHT_START = "thought_start"  # Agent thought start
    THOUGHT_END = "thought_end"  # Agent thought end

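Note that this enum is duplicated in the task-entity layer further down; the two copies are bridged by round-tripping through the string value, which StrEnum makes cheap. For illustration:

```python
# StrEnum values round-trip as plain strings, which is how the queue layer
# and the task layer convert between their two ChunkType copies.
assert ChunkType("tool_call") is ChunkType.TOOL_CALL
assert ChunkType.THOUGHT.value == "thought"
```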
class QueueTextChunkEvent(AppQueueEvent):
    """
    QueueTextChunkEvent entity
@ -191,6 +202,16 @@ class QueueTextChunkEvent(AppQueueEvent):
    in_loop_id: str | None = None
    """loop id if node is in loop"""

    # Extended fields for Agent/Tool streaming
    chunk_type: ChunkType = ChunkType.TEXT
    """type of the chunk"""

    # Tool streaming payloads
    tool_call: ToolCall | None = None
    """structured tool call info"""
    tool_result: ToolResult | None = None
    """structured tool result info"""


class QueueAgentMessageEvent(AppQueueEvent):
    """

@ -113,6 +113,38 @@ class MessageStreamResponse(StreamResponse):
    answer: str
    from_variable_selector: list[str] | None = None

    # Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
    chunk_type: str | None = None
    """type of the chunk: text, tool_call, tool_result, thought"""

    # Tool call fields (when chunk_type == "tool_call")
    tool_call_id: str | None = None
    """unique identifier for this tool call"""
    tool_name: str | None = None
    """name of the tool being called"""
    tool_arguments: str | None = None
    """accumulated tool arguments JSON"""

    # Tool result fields (when chunk_type == "tool_result")
    tool_files: list[str] | None = None
    """file IDs produced by tool"""
    tool_error: str | None = None
    """error message if tool failed"""
    tool_elapsed_time: float | None = None
    """elapsed time spent executing the tool"""
    tool_icon: str | dict | None = None
    """icon of the tool"""
    tool_icon_dark: str | dict | None = None
    """dark theme icon of the tool"""

    def model_dump(self, *args, **kwargs) -> dict[str, object]:
        kwargs.setdefault("exclude_none", True)
        return super().model_dump(*args, **kwargs)

    def model_dump_json(self, *args, **kwargs) -> str:
        kwargs.setdefault("exclude_none", True)
        return super().model_dump_json(*args, **kwargs)


class MessageAudioStreamResponse(StreamResponse):
    """
@ -582,6 +614,17 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
    data: Data


class ChunkType(StrEnum):
    """Stream chunk type for LLM-related events."""

    TEXT = "text"  # Normal text streaming
    TOOL_CALL = "tool_call"  # Tool call arguments streaming
    TOOL_RESULT = "tool_result"  # Tool execution result
    THOUGHT = "thought"  # Agent thinking process (ReAct)
    THOUGHT_START = "thought_start"  # Agent thought start
    THOUGHT_END = "thought_end"  # Agent thought end


class TextChunkStreamResponse(StreamResponse):
    """
    TextChunkStreamResponse entity
@ -595,6 +638,36 @@ class TextChunkStreamResponse(StreamResponse):
        text: str
        from_variable_selector: list[str] | None = None

        # Extended fields for Agent/Tool streaming
        chunk_type: ChunkType = ChunkType.TEXT
        """type of the chunk"""

        # Tool call fields (when chunk_type == TOOL_CALL)
        tool_call_id: str | None = None
        """unique identifier for this tool call"""
        tool_name: str | None = None
        """name of the tool being called"""
        tool_arguments: str | None = None
        """accumulated tool arguments JSON"""

        # Tool result fields (when chunk_type == TOOL_RESULT)
        tool_files: list[str] | None = None
        """file IDs produced by tool"""
        tool_error: str | None = None
        """error message if tool failed"""

        # Tool elapsed time fields (when chunk_type == TOOL_RESULT)
        tool_elapsed_time: float | None = None
        """elapsed time spent executing the tool"""

        def model_dump(self, *args, **kwargs) -> dict[str, object]:
            kwargs.setdefault("exclude_none", True)
            return super().model_dump(*args, **kwargs)

        def model_dump_json(self, *args, **kwargs) -> str:
            kwargs.setdefault("exclude_none", True)
            return super().model_dump_json(*args, **kwargs)

    event: StreamEvent = StreamEvent.TEXT_CHUNK
    data: Data

@ -743,7 +816,7 @@ class AgentLogStreamResponse(StreamResponse):
        """

        node_execution_id: str
        id: str
        message_id: str
        label: str
        parent_id: str | None = None
        error: str | None = None

60
api/core/app/layers/conversation_variable_persist_layer.py
Normal file
@ -0,0 +1,60 @@
import logging

from core.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers

logger = logging.getLogger(__name__)


class ConversationVariablePersistenceLayer(GraphEngineLayer):
    def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
        super().__init__()
        self._conversation_variable_updater = conversation_variable_updater

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        if not isinstance(event, NodeRunSucceededEvent):
            return
        if event.node_type != NodeType.VARIABLE_ASSIGNER:
            return
        if self.graph_runtime_state is None:
            return

        updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
        if not updated_variables:
            return

        conversation_id = self.graph_runtime_state.system_variable.conversation_id
        if conversation_id is None:
            return

        updated_any = False
        for item in updated_variables:
            selector = item.selector
            if len(selector) < 2:
                logger.warning("Conversation variable selector invalid. selector=%s", selector)
                continue
            if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
                continue
            variable = self.graph_runtime_state.variable_pool.get(selector)
            if not isinstance(variable, Variable):
                logger.warning(
                    "Conversation variable not found in variable pool. selector=%s",
                    selector,
                )
                continue
            self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
            updated_any = True

        if updated_any:
            self._conversation_variable_updater.flush()

    def on_graph_end(self, error: Exception | None) -> None:
        pass
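The layer only relies on two methods of the updater, called in a fixed order: `update()` once per changed conversation variable, then a single `flush()`. A hedged sketch of that contract with a hypothetical test double:

```python
# Hypothetical recording updater, for illustration only; it assumes
# ConversationVariableUpdater can be subclassed with these two methods.
class RecordingUpdater(ConversationVariableUpdater):
    def __init__(self) -> None:
        self.updates: list[tuple[str, Variable]] = []
        self.flushed = False

    def update(self, *, conversation_id: str, variable: Variable) -> None:
        self.updates.append((conversation_id, variable))

    def flush(self) -> None:
        self.flushed = True

layer = ConversationVariablePersistenceLayer(RecordingUpdater())
```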
@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread
@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought

logger = logging.getLogger(__name__)

@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
    EasyUIBasedGenerateTaskPipeline is a class that generates stream output and manages state for an application.
    """

    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

    _task_state: EasyUITaskState
    _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]

@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
            )
        )

        # Save LLM generation detail if there's reasoning_content
        self._save_generation_detail(session=session, message=message, llm_result=llm_result)

        message_was_created.send(
            message,
            application_generate_entity=self._application_generate_entity,
        )

    def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
        """
        Save LLM generation detail for Completion/Chat/Agent-Chat applications.
        For Agent-Chat, also merges MessageAgentThought records.
        """
        import json

        reasoning_list: list[str] = []
        tool_calls_list: list[dict] = []
        sequence: list[dict] = []
        answer = message.answer or ""

        # Check if this is Agent-Chat mode by looking for agent thoughts
        agent_thoughts = (
            session.query(MessageAgentThought)
            .filter_by(message_id=message.id)
            .order_by(MessageAgentThought.position.asc())
            .all()
        )

        if agent_thoughts:
            # Agent-Chat mode: merge MessageAgentThought records
            content_pos = 0
            cleaned_answer_parts: list[str] = []
            for thought in agent_thoughts:
                # Add thought/reasoning
                if thought.thought:
                    reasoning_text = thought.thought
                    if "<think" in reasoning_text.lower():
                        clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
                        if extracted_reasoning:
                            reasoning_text = extracted_reasoning
                            thought.thought = clean_text or extracted_reasoning
                    reasoning_list.append(reasoning_text)
                    sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})

                # Add tool calls
                if thought.tool:
                    tool_calls_list.append(
                        {
                            "name": thought.tool,
                            "arguments": thought.tool_input or "",
                            "result": thought.observation or "",
                        }
                    )
                    sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})

                # Add answer content if present
                if thought.answer:
                    content_text = thought.answer
                    if "<think" in content_text.lower():
                        clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
                        if extracted_reasoning:
                            reasoning_list.append(extracted_reasoning)
                            sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
                            content_text = clean_answer
                            thought.answer = clean_answer or content_text

                    if content_text:
                        start = content_pos
                        end = content_pos + len(content_text)
                        sequence.append({"type": "content", "start": start, "end": end})
                        content_pos = end
                        cleaned_answer_parts.append(content_text)

            if cleaned_answer_parts:
                merged_answer = "".join(cleaned_answer_parts)
                message.answer = merged_answer
                llm_result.message.content = merged_answer
        else:
            # Completion/Chat mode: use reasoning_content from llm_result
            reasoning_content = llm_result.reasoning_content
            if not reasoning_content and answer:
                # Extract reasoning from <think> blocks and clean the final answer
                clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
                if reasoning_content:
                    answer = clean_answer
                    llm_result.message.content = clean_answer
                    llm_result.reasoning_content = reasoning_content
                    message.answer = clean_answer
            if reasoning_content:
                reasoning_list = [reasoning_content]
                # Content comes first, then reasoning
                if answer:
                    sequence.append({"type": "content", "start": 0, "end": len(answer)})
                sequence.append({"type": "reasoning", "index": 0})

        # Only save if there's meaningful generation detail
        if not reasoning_list and not tool_calls_list:
            return

        # Check if generation detail already exists
        existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()

        if existing:
            existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
            existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
            existing.sequence = json.dumps(sequence) if sequence else None
        else:
            generation_detail = LLMGenerationDetail(
                tenant_id=self._application_generate_entity.app_config.tenant_id,
                app_id=self._application_generate_entity.app_config.app_id,
                message_id=message.id,
                reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
                tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
                sequence=json.dumps(sequence) if sequence else None,
            )
            session.add(generation_detail)

    @classmethod
    def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
        """
        Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
        """
        matches = cls._THINK_PATTERN.findall(text)
        reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""

        clean_text = cls._THINK_PATTERN.sub("", text)
        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()

        return clean_text, reasoning_content or ""

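A worked example of the split helper, with an invented input:

```python
# Two think blocks are extracted, joined with a newline, and stripped from
# the visible answer.
text = "<think>step one</think>Answer part.<think>step two</think> Done."
# _split_reasoning_from_answer(text)
# -> ("Answer part. Done.", "step one\nstep two")
```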
    def _handle_stop(self, event: QueueStopEvent):
        """
        Handle stop.

@ -232,15 +232,31 @@ class MessageCycleManager:
        answer: str,
        message_id: str,
        from_variable_selector: list[str] | None = None,
        chunk_type: str | None = None,
        tool_call_id: str | None = None,
        tool_name: str | None = None,
        tool_arguments: str | None = None,
        tool_files: list[str] | None = None,
        tool_error: str | None = None,
        tool_elapsed_time: float | None = None,
        tool_icon: str | dict | None = None,
        tool_icon_dark: str | dict | None = None,
        event_type: StreamEvent | None = None,
    ) -> MessageStreamResponse:
        """
        Message to stream response.
        :param answer: answer
        :param message_id: message id
        :param from_variable_selector: from variable selector
        :param chunk_type: type of the chunk (text, tool_call, tool_result, thought)
        :param tool_call_id: unique identifier for this tool call
        :param tool_name: name of the tool being called
        :param tool_arguments: accumulated tool arguments JSON
        :param tool_files: file IDs produced by tool
        :param tool_error: error message if tool failed
        :return:
        """
        return MessageStreamResponse(
        response = MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer,
@ -248,6 +264,35 @@ class MessageCycleManager:
            event=event_type or StreamEvent.MESSAGE,
        )

        if chunk_type:
            response = response.model_copy(update={"chunk_type": chunk_type})

        if chunk_type == "tool_call":
            response = response.model_copy(
                update={
                    "tool_call_id": tool_call_id,
                    "tool_name": tool_name,
                    "tool_arguments": tool_arguments,
                    "tool_icon": tool_icon,
                    "tool_icon_dark": tool_icon_dark,
                }
            )
        elif chunk_type == "tool_result":
            response = response.model_copy(
                update={
                    "tool_call_id": tool_call_id,
                    "tool_name": tool_name,
                    "tool_arguments": tool_arguments,
                    "tool_files": tool_files,
                    "tool_error": tool_error,
                    "tool_elapsed_time": tool_elapsed_time,
                    "tool_icon": tool_icon,
                    "tool_icon_dark": tool_icon_dark,
                }
            )

        return response

    def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
        """
        Message replace to stream response.

@ -5,7 +5,6 @@ from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
    # TODO(-LAN-): Improve type check
    def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
        """Handle return_retriever_resource_info."""
        from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

        self._queue_manager.publish(
            QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
        )

@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from configs import dify_config
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
        """
        return DatasourceProviderType.LOCAL_FILE

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
        return self.__class__(
            entity=self.entity.model_copy(),
            runtime=runtime,

@ -1,3 +1,5 @@
from __future__ import annotations

import enum
from enum import StrEnum
from typing import Any
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
    ONLINE_DRIVE = "online_drive"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
    def value_of(cls, value: str) -> DatasourceProviderType:
        """
        Get value of given mode.

@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
        typ: DatasourceParameterType,
        required: bool,
        options: list[str] | None = None,
    ) -> "DatasourceParameter":
    ) -> DatasourceParameter:
        """
        get a simple datasource parameter

@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
    tool_config: dict | None = None

    @classmethod
    def empty(cls) -> "DatasourceInvokeMeta":
    def empty(cls) -> DatasourceInvokeMeta:
        """
        Get an empty instance of DatasourceInvokeMeta
        """
        return cls(time_cost=0.0, error=None, tool_config={})

    @classmethod
    def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
    def error_instance(cls, error: str) -> DatasourceInvokeMeta:
        """
        Get an instance of DatasourceInvokeMeta with error
        """

@ -1,7 +1,7 @@
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

_session_maker: sessionmaker | None = None
_session_maker: sessionmaker[Session] | None = None


def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
    _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)


def get_session_maker() -> sessionmaker:
def get_session_maker() -> sessionmaker[Session]:
    if _session_maker is None:
        raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
    return _session_maker
@ -27,7 +27,7 @@ class SessionFactory:
        configure_session_factory(engine, expire_on_commit)

    @staticmethod
    def get_session_maker() -> sessionmaker:
    def get_session_maker() -> sessionmaker[Session]:
        return get_session_maker()

    @staticmethod

@ -1,3 +1,5 @@
from __future__ import annotations

import json
from datetime import datetime
from enum import StrEnum
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
    updated_at: datetime

    @classmethod
    def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
    def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
        """Create entity from database model with decryption"""

        return cls(

@ -1,3 +1,5 @@
from __future__ import annotations

from enum import StrEnum, auto
from typing import Union

@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
        TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR

        @classmethod
        def value_of(cls, value: str) -> "ProviderConfig.Type":
        def value_of(cls, value: str) -> ProviderConfig.Type:
            """
            Get value of given mode.


@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config


def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
    url = f"{base_url}/files/{upload_file_id}/file-preview"

    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()

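A short illustration of the new base-URL switch (config values invented for the example):

```python
# Assuming FILES_URL = "https://files.example.com" and
# INTERNAL_FILES_URL = "http://files.internal:5001":
#
# get_signed_file_url("abc", for_external=True)
#   -> "https://files.example.com/files/abc/file-preview?..."
# get_signed_file_url("abc", for_external=False)
#   -> "http://files.internal:5001/files/abc/file-preview?..."
#
# With INTERNAL_FILES_URL unset, both calls fall back to FILES_URL.
```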
@ -112,17 +112,17 @@ class File(BaseModel):

        return text

    def generate_url(self) -> str | None:
    def generate_url(self, for_external: bool = True) -> str | None:
        if self.transfer_method == FileTransferMethod.REMOTE_URL:
            return self.remote_url
        elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
            if self.related_id is None:
                raise ValueError("Missing file related_id")
            return helpers.get_signed_file_url(upload_file_id=self.related_id)
            return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
        elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
            assert self.related_id is not None
            assert self.extension is not None
            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
        return None

    def to_plugin_parameter(self) -> dict[str, Any]:
@ -133,7 +133,7 @@ class File(BaseModel):
            "extension": self.extension,
            "size": self.size,
            "type": self.type,
            "url": self.generate_url(),
            "url": self.generate_url(for_external=False),
        }

    @model_validator(mode="after")

@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
        Post-process the result to convert scientific notation strings back to numbers
        """

        def convert_scientific_notation(value):
        def convert_scientific_notation(value: Any) -> Any:
            if isinstance(value, str):
                # Check if the string looks like scientific notation
                if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
                return [convert_scientific_notation(v) for v in value]
            return value

        return convert_scientific_notation(result)  # type: ignore[no-any-return]
        return convert_scientific_notation(result)

    @classmethod
    @abstractmethod

@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
        request_id: RequestId,
        request_meta: RequestParams.Meta | None,
        request: ReceiveRequestT,
        session: """BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT
        ]""",
        session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
    ):
        self.request_id = request_id

@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
    TOOL = auto()

    @classmethod
    def value_of(cls, value: str) -> "PromptMessageRole":
    def value_of(cls, value: str) -> PromptMessageRole:
        """
        Get value of given mode.


@ -1,3 +1,5 @@
from __future__ import annotations

from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@ -20,7 +22,7 @@ class ModelType(StrEnum):
    TTS = auto()

    @classmethod
    def value_of(cls, origin_model_type: str) -> "ModelType":
    def value_of(cls, origin_model_type: str) -> ModelType:
        """
        Get model type from origin model type.

@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
    JSON_SCHEMA = auto()

    @classmethod
    def value_of(cls, value: Any) -> "DefaultParameterName":
    def value_of(cls, value: Any) -> DefaultParameterName:
        """
        Get parameter name from value.


@ -1,3 +1,5 @@
from __future__ import annotations

import hashlib
import logging
from collections.abc import Sequence
@ -38,7 +40,7 @@ class ModelProviderFactory:
        plugin_providers = self.get_plugin_model_providers()
        return [provider.declaration for provider in plugin_providers]

    def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
    def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
        """
        Get all plugin model providers
        :return: list of plugin model providers
@ -76,7 +78,7 @@ class ModelProviderFactory:
        plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
        return plugin_model_provider_entity.declaration

    def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
    def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
        """
        Get plugin model provider
        :param provider: provider name

@ -1,3 +1,5 @@
from __future__ import annotations

import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
        return [item.value for item in cls]

    @classmethod
    def of(cls, credential_type: str) -> "CredentialType":
    def of(cls, credential_type: str) -> CredentialType:
        type_name = credential_type.lower()
        if type_name in {"api-key", "api_key"}:
            return cls.API_KEY

@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
import json
import logging
@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

import clickzetta  # type: ignore
from pydantic import BaseModel, model_validator
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
    Manages connection reuse across ClickzettaVector instances.
    """

    _instance: Optional["ClickzettaConnectionPool"] = None
    _instance: ClickzettaConnectionPool | None = None
    _lock = threading.Lock()

    def __init__(self):
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
        self._start_cleanup_thread()

    @classmethod
    def get_instance(cls) -> "ClickzettaConnectionPool":
    def get_instance(cls) -> ClickzettaConnectionPool:
        """Get singleton instance of connection pool."""
        if cls._instance is None:
            with cls._lock:
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
            f"{config.workspace}:{config.vcluster}:{config.schema_name}"
        )

    def _create_connection(self, config: ClickzettaConfig) -> "Connection":
    def _create_connection(self, config: ClickzettaConfig) -> Connection:
        """Create a new ClickZetta connection."""
        max_retries = 3
        retry_delay = 1.0
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:

        raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")

    def _configure_connection(self, connection: "Connection"):
    def _configure_connection(self, connection: Connection):
        """Configure connection session settings."""
        try:
            with connection.cursor() as cursor:
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
        except Exception:
            logger.exception("Failed to configure connection, continuing with defaults")

    def _is_connection_valid(self, connection: "Connection") -> bool:
    def _is_connection_valid(self, connection: Connection) -> bool:
        """Check if connection is still valid."""
        try:
            with connection.cursor() as cursor:
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
        except Exception:
            return False

    def get_connection(self, config: ClickzettaConfig) -> "Connection":
    def get_connection(self, config: ClickzettaConfig) -> Connection:
        """Get a connection from the pool or create a new one."""
        config_key = self._get_config_key(config)

@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
        # No valid connection found, create new one
        return self._create_connection(config)

    def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
    def return_connection(self, config: ClickzettaConfig, connection: Connection):
        """Return a connection to the pool."""
        config_key = self._get_config_key(config)

@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
        self._connection_pool = ClickzettaConnectionPool.get_instance()
        self._init_write_queue()

    def _get_connection(self) -> "Connection":
    def _get_connection(self) -> Connection:
        """Get a connection from the pool."""
        return self._connection_pool.get_connection(self._config)

    def _return_connection(self, connection: "Connection"):
    def _return_connection(self, connection: Connection):
        """Return a connection to the pool."""
        self._connection_pool.return_connection(self._config, connection)

    class ConnectionContext:
        """Context manager for borrowing and returning connections."""

        def __init__(self, vector_instance: "ClickzettaVector"):
        def __init__(self, vector_instance: ClickzettaVector):
            self.vector = vector_instance
            self.connection: Connection | None = None

        def __enter__(self) -> "Connection":
        def __enter__(self) -> Connection:
            self.connection = self.vector._get_connection()
            return self.connection

@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
            if self.connection:
                self.vector._return_connection(self.connection)

    def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
    def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
        """Get a connection context manager."""
        return self.ConnectionContext(self)

@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
        """Return the vector database type."""
        return "clickzetta"

    def _ensure_connection(self) -> "Connection":
    def _ensure_connection(self) -> Connection:
        """Get a connection from the pool."""
        return self._get_connection()

@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):

        # No need for dataset_id filter since each dataset has its own table

        # Use simple quote escaping for LIKE clause
        escaped_query = query.replace("'", "''")
        filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
        # Escape special characters for LIKE clause to prevent SQL injection
        from libs.helper import escape_like_pattern

        escaped_query = escape_like_pattern(query).replace("'", "''")
        filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
        where_clause = " AND ".join(filter_clauses)

        search_sql = f"""

@ -287,11 +287,15 @@ class IrisVector(BaseVector):
            cursor.execute(sql, (query,))
        else:
            # Fallback to LIKE search (inefficient for large datasets)
            query_pattern = f"%{query}%"
            # Escape special characters for LIKE clause to prevent SQL injection
            from libs.helper import escape_like_pattern

            escaped_query = escape_like_pattern(query)
            query_pattern = f"%{escaped_query}%"
            sql = f"""
                SELECT TOP {top_k} id, text, meta
                FROM {self.schema}.{self.table_name}
                WHERE text LIKE ?
                WHERE text LIKE ? ESCAPE '\\'
            """
            cursor.execute(sql, (query_pattern,))
||||
|
||||
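Both vector stores now route user input through `escape_like_pattern` before it reaches a LIKE clause. The helper itself is not shown in this diff; a minimal sketch of what such a helper plausibly does (escape the escape character first, then the `%` and `_` wildcards, so they match literally under `ESCAPE '\'`) might look like this:

```python
# Hypothetical sketch only; the real escape_like_pattern lives in
# libs/helper.py and is not included in this diff.
def escape_like_pattern(pattern: str) -> str:
    # Escape the backslash first so later escapes are not double-escaped,
    # then neutralize the LIKE wildcards % and _.
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


if __name__ == "__main__":
    assert escape_like_pattern("50%_off") == "50\\%\\_off"
```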
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
 from typing import Any

@@ -22,7 +24,7 @@ class DatasetDocumentStore:
         self._document_id = document_id

     @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
+    def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
        return cls(**config_dict)

     def to_dict(self) -> dict[str, Any]:
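Many hunks below repeat this same refactor: add `from __future__ import annotations` and drop the quotes around forward-referenced type names. A minimal standalone illustration of why the quotes become unnecessary (class name invented for the example):

```python
from __future__ import annotations  # PEP 563: annotations are evaluated lazily


class Node:
    # Without the future import, naming the enclosing class in an annotation
    # requires the quoted form "Node"; with it, the bare name works because
    # annotations are no longer evaluated at class-creation time.
    def clone(self) -> Node:
        return Node()


print(Node().clone())  # a fresh Node instance
```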
@@ -7,10 +7,11 @@ import re
 import tempfile
 import uuid
 from urllib.parse import urlparse
-from xml.etree import ElementTree

 import httpx
 from docx import Document as DocxDocument
+from docx.oxml.ns import qn
+from docx.text.run import Run

 from configs import dify_config
 from core.helper import ssrf_proxy

@@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):

         image_map = self._extract_images_from_docx(doc)

-        hyperlinks_url = None
-        url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
-        for para in doc.paragraphs:
-            for run in para.runs:
-                if run.text and hyperlinks_url:
-                    result = f" [{run.text}]({hyperlinks_url}) "
-                    run.text = result
-                    hyperlinks_url = None
-                if "HYPERLINK" in run.element.xml:
-                    try:
-                        xml = ElementTree.XML(run.element.xml)
-                        x_child = [c for c in xml.iter() if c is not None]
-                        for x in x_child:
-                            if x is None:
-                                continue
-                            if x.tag.endswith("instrText"):
-                                if x.text is None:
-                                    continue
-                                for i in url_pattern.findall(x.text):
-                                    hyperlinks_url = str(i)
-                    except Exception:
-                        logger.exception("Failed to parse HYPERLINK xml")
-
         def parse_paragraph(paragraph):
-            paragraph_content = []
-
-            def append_image_link(image_id, has_drawing):
+            def append_image_link(image_id, has_drawing, target_buffer):
                 """Helper to append image link from image_map based on relationship type."""
                 rel = doc.part.rels[image_id]
                 if rel.is_external:
                     if image_id in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_id])
+                        target_buffer.append(image_map[image_id])
                 else:
                     image_part = rel.target_part
                     if image_part in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_part])
+                        target_buffer.append(image_map[image_part])

-            for run in paragraph.runs:
+            def process_run(run, target_buffer):
+                # Helper to extract text and embedded images from a run element and append them to target_buffer
                 if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
                     # Process drawing type images
                     drawing_elements = run.element.findall(

@@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
                         # External image: use embed_id as key
                         if embed_id in image_map:
                             has_drawing = True
-                            paragraph_content.append(image_map[embed_id])
+                            target_buffer.append(image_map[embed_id])
                     else:
                         # Internal image: use target_part as key
                         image_part = doc.part.related_parts.get(embed_id)
                         if image_part in image_map:
                             has_drawing = True
-                            paragraph_content.append(image_map[image_part])
+                            target_buffer.append(image_map[image_part])
                 # Process pict type images
                 shape_elements = run.element.findall(
                     ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"

@@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
                             "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                         )
                         if image_id and image_id in doc.part.rels:
-                            append_image_link(image_id, has_drawing)
+                            append_image_link(image_id, has_drawing, target_buffer)
                     # Find imagedata element in VML
                     image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
                     if image_data is not None:

@@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
                             "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                         )
                         if image_id and image_id in doc.part.rels:
-                            append_image_link(image_id, has_drawing)
+                            append_image_link(image_id, has_drawing, target_buffer)
                 if run.text.strip():
-                    paragraph_content.append(run.text.strip())
+                    target_buffer.append(run.text.strip())
+
+            def process_hyperlink(hyperlink_elem, target_buffer):
+                # Helper to extract text from a hyperlink element and append it to target_buffer
+                r_id = hyperlink_elem.get(qn("r:id"))
+
+                # Extract text from runs inside the hyperlink
+                link_text_parts = []
+                for run_elem in hyperlink_elem.findall(qn("w:r")):
+                    run = Run(run_elem, paragraph)
+                    # Hyperlink text may be split across multiple runs (e.g., with different formatting),
+                    # so collect all run texts first
+                    if run.text:
+                        link_text_parts.append(run.text)
+
+                link_text = "".join(link_text_parts).strip()
+
+                # Resolve URL
+                if r_id:
+                    try:
+                        rel = doc.part.rels.get(r_id)
+                        if rel and rel.is_external:
+                            link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
+                    except Exception:
+                        logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
+
+                if link_text:
+                    target_buffer.append(link_text)
+
+            paragraph_content = []
+            # State for legacy HYPERLINK fields
+            hyperlink_field_url = None
+            hyperlink_field_text_parts: list = []
+            is_collecting_field_text = False
+            # Iterate through paragraph elements in document order
+            for child in paragraph._element:
+                tag = child.tag
+                if tag == qn("w:r"):
+                    # Regular run
+                    run = Run(child, paragraph)
+
+                    # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
+                    fld_chars = child.findall(qn("w:fldChar"))
+                    instr_texts = child.findall(qn("w:instrText"))
+
+                    # Handle Fields
+                    if fld_chars or instr_texts:
+                        # Process instrText to find HYPERLINK "url"
+                        for instr in instr_texts:
+                            if instr.text and "HYPERLINK" in instr.text:
+                                # Quick regex to extract URL
+                                match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
+                                if match:
+                                    hyperlink_field_url = match.group(1)
+
+                        # Process fldChar
+                        for fld_char in fld_chars:
+                            fld_char_type = fld_char.get(qn("w:fldCharType"))
+                            if fld_char_type == "begin":
+                                # Start of a field: reset legacy link state
+                                hyperlink_field_url = None
+                                hyperlink_field_text_parts = []
+                                is_collecting_field_text = False
+                            elif fld_char_type == "separate":
+                                # Separator: if we found a URL, start collecting visible text
+                                if hyperlink_field_url:
+                                    is_collecting_field_text = True
+                            elif fld_char_type == "end":
+                                # End of field
+                                if is_collecting_field_text and hyperlink_field_url:
+                                    # Create markdown link and append to main content
+                                    display_text = "".join(hyperlink_field_text_parts).strip()
+                                    if display_text:
+                                        link_md = f"[{display_text}]({hyperlink_field_url})"
+                                        paragraph_content.append(link_md)
+                                    # Reset state
+                                    hyperlink_field_url = None
+                                    hyperlink_field_text_parts = []
+                                    is_collecting_field_text = False
+
+                    # Decide where to append content
+                    target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
+                    process_run(run, target_buffer)
+                elif tag == qn("w:hyperlink"):
+                    process_hyperlink(child, paragraph_content)
             return "".join(paragraph_content) if paragraph_content else ""

         paragraphs = doc.paragraphs.copy()
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from collections.abc import Sequence
 from typing import Any

@@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
         return self.model_dump_json()

     @classmethod
-    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
+    def deserialize(cls, serialized_data: str) -> TaskWrapper:
         return cls.model_validate_json(serialized_data)
@@ -1198,18 +1198,24 @@ class DatasetRetrieval:

                 json_field = DatasetDocument.doc_metadata[metadata_name].as_string()

+                from libs.helper import escape_like_pattern
+
                 match condition:
                     case "contains":
-                        filters.append(json_field.like(f"%{value}%"))
+                        escaped_value = escape_like_pattern(str(value))
+                        filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))

                     case "not contains":
-                        filters.append(json_field.notlike(f"%{value}%"))
+                        escaped_value = escape_like_pattern(str(value))
+                        filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))

                     case "start with":
-                        filters.append(json_field.like(f"{value}%"))
+                        escaped_value = escape_like_pattern(str(value))
+                        filters.append(json_field.like(f"{escaped_value}%", escape="\\"))

                     case "end with":
-                        filters.append(json_field.like(f"%{value}"))
+                        escaped_value = escape_like_pattern(str(value))
+                        filters.append(json_field.like(f"%{escaped_value}", escape="\\"))

                     case "is" | "=":
                         if isinstance(value, str):
@@ -29,6 +29,7 @@ from models import (
     Account,
     CreatorUserRole,
     EndUser,
+    LLMGenerationDetail,
     WorkflowNodeExecutionModel,
     WorkflowNodeExecutionTriggeredFrom,
 )

@@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             session.merge(db_model)
             session.flush()

+            # Save LLMGenerationDetail for LLM nodes with successful execution
+            if (
+                domain_model.node_type == NodeType.LLM
+                and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
+                and domain_model.outputs is not None
+            ):
+                self._save_llm_generation_detail(session, domain_model)
+
+    def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
+        """
+        Save LLM generation detail for LLM nodes.
+        Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
+        """
+        outputs = execution.outputs or {}
+        metadata = execution.metadata or {}
+
+        reasoning_list = self._extract_reasoning(outputs)
+        tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
+
+        if not reasoning_list and not tool_calls_list:
+            return
+
+        sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
+        self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
+
+    def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
+        """Extract reasoning_content as a clean list of non-empty strings."""
+        reasoning_content = outputs.get("reasoning_content")
+        if isinstance(reasoning_content, str):
+            trimmed = reasoning_content.strip()
+            return [trimmed] if trimmed else []
+        if isinstance(reasoning_content, list):
+            return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
+        return []
+
+    def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
+        """Extract tool call records from agent logs."""
+        if not agent_log or not isinstance(agent_log, list):
+            return []
+
+        tool_calls: list[dict[str, str]] = []
+        for log in agent_log:
+            log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
+            tool_name = log_data.get("tool_name")
+            if tool_name and str(tool_name).strip():
+                tool_calls.append(
+                    {
+                        "id": log_data.get("tool_call_id", ""),
+                        "name": tool_name,
+                        "arguments": json.dumps(log_data.get("tool_args", {})),
+                        "result": str(log_data.get("output", "")),
+                    }
+                )
+        return tool_calls
+
+    def _build_generation_sequence(
+        self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
+    ) -> list[dict[str, Any]]:
+        """Build a simple content/reasoning/tool_call sequence."""
+        sequence: list[dict[str, Any]] = []
+        if text:
+            sequence.append({"type": "content", "start": 0, "end": len(text)})
+        for index in range(len(reasoning_list)):
+            sequence.append({"type": "reasoning", "index": index})
+        for index in range(len(tool_calls_list)):
+            sequence.append({"type": "tool_call", "index": index})
+        return sequence
+
+    def _upsert_generation_detail(
+        self,
+        session,
+        execution: WorkflowNodeExecution,
+        reasoning_list: list[str],
+        tool_calls_list: list[dict[str, str]],
+        sequence: list[dict[str, Any]],
+    ) -> None:
+        """Insert or update LLMGenerationDetail with serialized fields."""
+        existing = (
+            session.query(LLMGenerationDetail)
+            .filter_by(
+                workflow_run_id=execution.workflow_execution_id,
+                node_id=execution.node_id,
+            )
+            .first()
+        )
+
+        reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
+        tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
+        sequence_json = json.dumps(sequence) if sequence else None
+
+        if existing:
+            existing.reasoning_content = reasoning_json
+            existing.tool_calls = tool_calls_json
+            existing.sequence = sequence_json
+            return
+
+        generation_detail = LLMGenerationDetail(
+            tenant_id=self._tenant_id,
+            app_id=self._app_id,
+            workflow_run_id=execution.workflow_execution_id,
+            node_id=execution.node_id,
+            reasoning_content=reasoning_json,
+            tool_calls=tool_calls_json,
+            sequence=sequence_json,
+        )
+        session.add(generation_detail)
+
     def get_db_models_by_workflow_run(
         self,
         workflow_run_id: str,
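For reference, the sequence built by `_build_generation_sequence` is a flat ordered list mixing a content span with indices into the reasoning and tool-call lists. A worked example using standalone stand-ins for the repository method (inputs are illustrative):

```python
import json

# Illustrative inputs; in the repository these come from outputs/metadata.
text = "The answer is 42."
reasoning_list = ["First, parse the question."]
tool_calls_list = [{"id": "c1", "name": "search", "arguments": "{}", "result": "42"}]

sequence = []
if text:
    sequence.append({"type": "content", "start": 0, "end": len(text)})
for index in range(len(reasoning_list)):
    sequence.append({"type": "reasoning", "index": index})
for index in range(len(tool_calls_list)):
    sequence.append({"type": "tool_call", "index": index})

print(json.dumps(sequence))
# [{"type": "content", "start": 0, "end": 17},
#  {"type": "reasoning", "index": 0},
#  {"type": "tool_call", "index": 0}]
```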
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import json
 import logging
 import threading
 from collections.abc import Mapping, MutableMapping
 from pathlib import Path
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar


 class SchemaRegistry:

@@ -11,7 +13,7 @@ class SchemaRegistry:

     logger: ClassVar[logging.Logger] = logging.getLogger(__name__)

-    _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
+    _default_instance: ClassVar[SchemaRegistry | None] = None
     _lock: ClassVar[threading.Lock] = threading.Lock()

     def __init__(self, base_dir: str):

@@ -20,7 +22,7 @@ class SchemaRegistry:
         self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}

     @classmethod
-    def default_registry(cls) -> "SchemaRegistry":
+    def default_registry(cls) -> SchemaRegistry:
         """Returns the default schema registry for builtin schemas (thread-safe singleton)"""
         if cls._default_instance is None:
             with cls._lock:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections.abc import Generator
 from copy import deepcopy

@@ -6,6 +8,7 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from models.model import File

+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import (
     ToolEntity,

@@ -24,7 +27,7 @@ class Tool(ABC):
         self.entity = entity
         self.runtime = runtime

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
         """
         fork a new tool with metadata
         :return: the new tool

@@ -152,6 +155,60 @@ class Tool(ABC):

         return parameters

+    def to_prompt_message_tool(self) -> PromptMessageTool:
+        message_tool = PromptMessageTool(
+            name=self.entity.identity.name,
+            description=self.entity.description.llm if self.entity.description else "",
+            parameters={
+                "type": "object",
+                "properties": {},
+                "required": [],
+            },
+        )
+
+        parameters = self.get_merged_runtime_parameters()
+        for parameter in parameters:
+            if parameter.form != ToolParameter.ToolParameterForm.LLM:
+                continue
+
+            parameter_type = parameter.type.as_normal_type()
+            if parameter.type in {
+                ToolParameter.ToolParameterType.SYSTEM_FILES,
+                ToolParameter.ToolParameterType.FILE,
+                ToolParameter.ToolParameterType.FILES,
+            }:
+                # Determine the description based on parameter type
+                if parameter.type == ToolParameter.ToolParameterType.FILE:
+                    file_format_desc = " Input the file id with format: [File: file_id]."
+                else:
+                    file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
+
+                message_tool.parameters["properties"][parameter.name] = {
+                    "type": "string",
+                    "description": (parameter.llm_description or "") + file_format_desc,
+                }
+                continue
+            enum = []
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options] if parameter.options else []
+
+            message_tool.parameters["properties"][parameter.name] = (
+                {
+                    "type": parameter_type,
+                    "description": parameter.llm_description or "",
+                }
+                if parameter.input_schema is None
+                else parameter.input_schema
+            )
+
+            if len(enum) > 0:
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum
+
+            if parameter.required:
+                message_tool.parameters["required"].append(parameter.name)
+
+        return message_tool
+
     def create_image_message(
         self,
         image: str,

@@ -166,7 +223,7 @@ class Tool(ABC):
             type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
         )

-    def create_file_message(self, file: "File") -> ToolInvokeMessage:
+    def create_file_message(self, file: File) -> ToolInvokeMessage:
         return ToolInvokeMessage(
             type=ToolInvokeMessage.MessageType.FILE,
             message=ToolInvokeMessage.FileMessage(),
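Based on the `to_prompt_message_tool` body above, a tool with one required SELECT parameter would serialize into a function-calling schema shaped roughly like this (all values illustrative, not taken from the diff):

```python
# Illustrative shape of message_tool.parameters after to_prompt_message_tool()
# processes a required SELECT parameter; field values are invented.
example_parameters = {
    "type": "object",
    "properties": {
        "language": {
            "type": "string",  # as_normal_type() result for a select parameter
            "description": "Target language for the translation",
            "enum": ["en", "zh", "ja"],  # option values copied from parameter.options
        },
    },
    "required": ["language"],
}
```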
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 from core.tools.__base.tool import Tool

@@ -24,7 +26,7 @@ class BuiltinTool(Tool):
         super().__init__(**kwargs)
         self.provider = provider

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
         """
         fork a new tool with metadata
         :return: the new tool
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pydantic import Field
 from sqlalchemy import select

@@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = []

     @classmethod
-    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
+    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
         credentials_schema = [
             ProviderConfig(
                 name="auth_type",
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import contextlib
 from collections.abc import Mapping

@@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
     MCP = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ToolProviderType":
+    def value_of(cls, value: str) -> ToolProviderType:
         """
         Get value of given mode.

@@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
     OPENAI_ACTIONS = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderSchemaType":
+    def value_of(cls, value: str) -> ApiProviderSchemaType:
         """
         Get value of given mode.

@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
     API_KEY_QUERY = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderAuthType":
+    def value_of(cls, value: str) -> ApiProviderAuthType:
         """
         Get value of given mode.

@@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
         typ: ToolParameterType,
         required: bool,
         options: list[str] | None = None,
-    ) -> "ToolParameter":
+    ) -> ToolParameter:
         """
         get a simple tool parameter

@@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
     tool_config: dict | None = None

     @classmethod
-    def empty(cls) -> "ToolInvokeMeta":
+    def empty(cls) -> ToolInvokeMeta:
         """
         Get an empty instance of ToolInvokeMeta
         """
         return cls(time_cost=0.0, error=None, tool_config={})

     @classmethod
-    def error_instance(cls, error: str) -> "ToolInvokeMeta":
+    def error_instance(cls, error: str) -> ToolInvokeMeta:
         """
         Get an instance of ToolInvokeMeta with error
         """
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import json
 import logging

@@ -118,7 +120,7 @@ class MCPTool(Tool):
         for item in json_list:
             yield self.create_json_message(item)

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
         return MCPTool(
             entity=self.entity,
             runtime=runtime,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Generator
 from typing import Any

@@ -46,7 +48,7 @@ class PluginTool(Tool):
             message_id=message_id,
         )

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
         return PluginTool(
             entity=self.entity,
             runtime=runtime,
@@ -7,12 +7,12 @@ import time

 from configs import dify_config


-def sign_tool_file(tool_file_id: str, extension: str) -> str:
+def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
     """
     sign file to get a temporary url for plugin access
     """
-    # Use internal URL for plugin/tool file access in Docker environments
-    base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+    # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
+    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
     file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"

     timestamp = str(int(time.time()))
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Mapping

 from pydantic import Field

@@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
         self.provider_id = provider_id

     @classmethod
-    def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
+    def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
         with session_factory.create_session() as session, session.begin():
             app = session.get(App, db_provider.app_id)
             if not app:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence

@@ -181,7 +183,7 @@ class WorkflowTool(Tool):
             return found
         return None

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
         """
         fork a new tool with metadata
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Mapping
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from core.file.models import File

@@ -52,7 +54,7 @@ class SegmentType(StrEnum):
         return self in _ARRAY_TYPES

     @classmethod
-    def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+    def infer_segment_type(cls, value: Any) -> SegmentType | None:
         """
         Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.

@@ -173,7 +175,7 @@ class SegmentType(StrEnum):
         raise AssertionError("this statement should be unreachable.")

     @staticmethod
-    def cast_value(value: Any, type_: "SegmentType"):
+    def cast_value(value: Any, type_: SegmentType):
         # Cast Python's `bool` type to `int` when the runtime type requires
         # an integer or number.
         #

@@ -193,7 +195,7 @@ class SegmentType(StrEnum):
             return [int(i) for i in value]
         return value

-    def exposed_type(self) -> "SegmentType":
+    def exposed_type(self) -> SegmentType:
         """Returns the type exposed to the frontend.

         The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.

@@ -202,7 +204,7 @@ class SegmentType(StrEnum):
             return SegmentType.NUMBER
         return self

-    def element_type(self) -> "SegmentType | None":
+    def element_type(self) -> SegmentType | None:
         """Return the element type of the current segment type, or `None` if the element type is undefined.

         Raises:

@@ -217,7 +219,7 @@ class SegmentType(StrEnum):
         return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)

     @staticmethod
-    def get_zero_value(t: "SegmentType"):
+    def get_zero_value(t: SegmentType):
         # Lazy import to avoid circular dependency
         from factories import variable_factory
@ -1,11 +1,16 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"ToolResult",
|
||||
"ToolResultStatus",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
api/core/workflow/entities/tool_entities.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from enum import StrEnum
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+
+
+class ToolResultStatus(StrEnum):
+    SUCCESS = "success"
+    ERROR = "error"
+
+
+class ToolCall(BaseModel):
+    id: str | None = Field(default=None, description="Unique identifier for this tool call")
+    name: str | None = Field(default=None, description="Name of the tool being called")
+    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
+
+
+class ToolResult(BaseModel):
+    id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
+    name: str | None = Field(default=None, description="Name of the tool")
+    output: str | None = Field(default=None, description="Tool output text, error or success message")
+    files: list[str] = Field(default_factory=list, description="File produced by tool")
+    status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
+    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
+
+
+class ToolCallResult(BaseModel):
+    id: str | None = Field(default=None, description="Identifier for the tool call")
+    name: str | None = Field(default=None, description="Name of the tool")
+    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    output: str | None = Field(default=None, description="Tool output text, error or success message")
+    files: list[File] = Field(default_factory=list, description="File produced by tool")
+    status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
+    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
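A quick usage sketch for the new entities, assuming the import path introduced by this file (values are illustrative):

```python
from core.workflow.entities.tool_entities import ToolCall, ToolResult, ToolResultStatus

call = ToolCall(id="call_1", name="web_search", arguments='{"query": "dify"}')
result = ToolResult(
    id="call_1",
    name="web_search",
    output="3 results found",
    status=ToolResultStatus.SUCCESS,
    elapsed_time=0.42,
)
# All fields are optional with defaults, so partial payloads round-trip cleanly.
print(call.model_dump_json())
print(result.status)  # ToolResultStatus.SUCCESS
```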
@@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
 implementation details like tenant_id, app_id, etc.
 """

+from __future__ import annotations
+
 from collections.abc import Mapping
 from datetime import datetime
 from typing import Any

@@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
         graph: Mapping[str, Any],
         inputs: Mapping[str, Any],
         started_at: datetime,
-    ) -> "WorkflowExecution":
+    ) -> WorkflowExecution:
         return WorkflowExecution(
             id_=id_,
             workflow_id=workflow_id,
@@ -248,6 +248,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
     LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
     DATASOURCE_INFO = "datasource_info"
+    LLM_CONTENT_SEQUENCE = "llm_content_sequence"
+    LLM_TRACE = "llm_trace"
     COMPLETED_REASON = "completed_reason"  # completed reason for loop node
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections import defaultdict
 from collections.abc import Mapping, Sequence

@@ -175,7 +177,7 @@ class Graph:
     def _create_node_instances(
         cls,
         node_configs_map: dict[str, dict[str, object]],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
     ) -> dict[str, Node]:
         """
         Create node instances from configurations using the node factory.

@@ -197,7 +199,7 @@ class Graph:
         return nodes

     @classmethod
-    def new(cls) -> "GraphBuilder":
+    def new(cls) -> GraphBuilder:
         """Create a fluent builder for assembling a graph programmatically."""

         return GraphBuilder(graph_cls=cls)

@@ -284,9 +286,9 @@ class Graph:
         cls,
         *,
         graph_config: Mapping[str, object],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
         root_node_id: str | None = None,
-    ) -> "Graph":
+    ) -> Graph:
         """
         Initialize graph

@@ -383,7 +385,7 @@ class GraphBuilder:
         self._edges: list[Edge] = []
         self._edge_counter = 0

-    def add_root(self, node: Node) -> "GraphBuilder":
+    def add_root(self, node: Node) -> GraphBuilder:
         """Register the root node. Must be called exactly once."""

         if self._nodes:

@@ -398,7 +400,7 @@ class GraphBuilder:
         *,
         from_node_id: str | None = None,
         source_handle: str = "source",
-    ) -> "GraphBuilder":
+    ) -> GraphBuilder:
         """Append a node and connect it from the specified predecessor."""

         if not self._nodes:

@@ -419,7 +421,7 @@ class GraphBuilder:

         return self

-    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
+    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
         """Connect two existing nodes without adding a new node."""

         if tail not in self._nodes_by_id:
@@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following
 Domain-Driven Design principles for improved maintainability and testability.
 """

+from __future__ import annotations
+
 import contextvars
 import logging
 import queue
+import threading
 from collections.abc import Generator
 from typing import TYPE_CHECKING, cast, final

@@ -75,10 +78,13 @@ class GraphEngine:
         scale_down_idle_time: float | None = None,
     ) -> None:
         """Initialize the graph engine with all subsystems and dependencies."""
+        # stop event
+        self._stop_event = threading.Event()

         # Bind runtime state to current workflow context
         self._graph = graph
         self._graph_runtime_state = graph_runtime_state
+        self._graph_runtime_state.stop_event = self._stop_event
         self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
         self._command_channel = command_channel

@@ -177,6 +183,7 @@ class GraphEngine:
             max_workers=self._max_workers,
             scale_up_threshold=self._scale_up_threshold,
             scale_down_idle_time=self._scale_down_idle_time,
+            stop_event=self._stop_event,
         )

         # === Orchestration ===

@@ -207,6 +214,7 @@ class GraphEngine:
             event_handler=self._event_handler_registry,
             execution_coordinator=self._execution_coordinator,
             event_emitter=self._event_manager,
+            stop_event=self._stop_event,
         )

         # === Validation ===

@@ -226,7 +234,7 @@ class GraphEngine:
     ) -> None:
         layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)

-    def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
+    def layer(self, layer: GraphEngineLayer) -> GraphEngine:
         """Add a layer for extending functionality."""
         self._layers.append(layer)
         self._bind_layer_context(layer)

@@ -324,6 +332,7 @@ class GraphEngine:

     def _start_execution(self, *, resume: bool = False) -> None:
         """Start execution subsystems."""
+        self._stop_event.clear()
         paused_nodes: list[str] = []
         if resume:
             paused_nodes = self._graph_runtime_state.consume_paused_nodes()

@@ -351,13 +360,12 @@ class GraphEngine:

     def _stop_execution(self) -> None:
         """Stop execution subsystems."""
+        self._stop_event.set()
         self._dispatcher.stop()
         self._worker_pool.stop()
         # Don't mark complete here as the dispatcher already does it

         # Notify layers
-        logger = logging.getLogger(__name__)
-
         for layer in self._layers:
             try:
                 layer.on_graph_end(self._graph_execution.error)
@@ -44,6 +44,7 @@ class Dispatcher:
         event_queue: queue.Queue[GraphNodeEventBase],
         event_handler: "EventHandler",
         execution_coordinator: ExecutionCoordinator,
+        stop_event: threading.Event,
         event_emitter: EventManager | None = None,
     ) -> None:
         """

@@ -61,7 +62,7 @@ class Dispatcher:
         self._event_emitter = event_emitter

         self._thread: threading.Thread | None = None
-        self._stop_event = threading.Event()
+        self._stop_event = stop_event
         self._start_time: float | None = None

     def start(self) -> None:

@@ -69,16 +70,14 @@ class Dispatcher:
         if self._thread and self._thread.is_alive():
             return

-        self._stop_event.clear()
         self._start_time = time.time()
         self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
         self._thread.start()

     def stop(self) -> None:
         """Stop the dispatcher thread."""
         self._stop_event.set()
         if self._thread and self._thread.is_alive():
-            self._thread.join(timeout=10.0)
+            self._thread.join(timeout=2.0)

     def _dispatcher_loop(self) -> None:
         """Main dispatcher loop."""
@@ -2,6 +2,8 @@
 Factory for creating ReadyQueue instances from serialized state.
 """

+from __future__ import annotations
+
 from typing import TYPE_CHECKING

 from .in_memory import InMemoryReadyQueue

@@ -11,7 +13,7 @@
 if TYPE_CHECKING:
     from .protocol import ReadyQueue


-def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
+def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
     """
     Create a ReadyQueue instance from a serialized state.
@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field

 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
-from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
+from core.workflow.graph_events import (
+    ChunkType,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ToolCall,
+    ToolResult,
+)
 from core.workflow.nodes.base.template import TextSegment, VariableSegment
 from core.workflow.runtime import VariablePool

@@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
         selector: Sequence[str],
         chunk: str,
         is_final: bool = False,
+        chunk_type: ChunkType = ChunkType.TEXT,
+        tool_call: ToolCall | None = None,
+        tool_result: ToolResult | None = None,
     ) -> NodeRunStreamChunkEvent:
         """Create a stream chunk event with consistent structure.

         For selectors with special prefixes (sys, env, conversation), we use the
         active response node's information since these are not actual node IDs.

+        Args:
+            node_id: The node ID to attribute the event to
+            execution_id: The execution ID for this node
+            selector: The variable selector
+            chunk: The chunk content
+            is_final: Whether this is the final chunk
+            chunk_type: The semantic type of the chunk being streamed
+            tool_call: Structured data for tool_call chunks
+            tool_result: Structured data for tool_result chunks
         """
         # Check if this is a special selector that doesn't correspond to a node
         if selector and selector[0] not in self._graph.nodes and self._active_session:

@@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
                 selector=selector,
                 chunk=chunk,
                 is_final=is_final,
+                chunk_type=chunk_type,
+                tool_call=tool_call,
+                tool_result=tool_result,
             )

         # Standard case: selector refers to an actual node

@@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
             selector=selector,
             chunk=chunk,
             is_final=is_final,
+            chunk_type=chunk_type,
+            tool_call=tool_call,
+            tool_result=tool_result,
         )

     def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:

@@ -356,6 +381,8 @@ class ResponseStreamCoordinator:

         Handles both regular node selectors and special system selectors (sys, env, conversation).
         For special selectors, we attribute the output to the active response node.
+
+        For object-type variables, automatically streams all child fields that have stream events.
         """
         events: list[NodeRunStreamChunkEvent] = []
         source_selector_prefix = segment.selector[0] if segment.selector else ""

@@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
         # Determine which node to attribute the output to
         # For special selectors (sys, env, conversation), use the active response node
         # For regular selectors, use the source node
-        if self._active_session and source_selector_prefix not in self._graph.nodes:
-            # Special selector - use active response node
-            output_node_id = self._active_session.node_id
-        else:
-            # Regular node selector
-            output_node_id = source_selector_prefix
+        active_session = self._active_session
+        special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
+        output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
         execution_id = self._get_or_create_execution_id(output_node_id)

-        # Stream all available chunks
-        while self._has_unread_stream(segment.selector):
-            if event := self._pop_stream_chunk(segment.selector):
-                # For special selectors, we need to update the event to use
-                # the active response node's information
-                if self._active_session and source_selector_prefix not in self._graph.nodes:
-                    response_node = self._graph.nodes[self._active_session.node_id]
-                    # Create a new event with the response node's information
-                    # but keep the original selector
-                    updated_event = NodeRunStreamChunkEvent(
-                        id=execution_id,
-                        node_id=response_node.id,
-                        node_type=response_node.node_type,
-                        selector=event.selector,  # Keep original selector
-                        chunk=event.chunk,
-                        is_final=event.is_final,
-                    )
-                    events.append(updated_event)
-                else:
-                    # Regular node selector - use event as is
-                    events.append(event)
+        # Check if there's a direct stream for this selector
+        has_direct_stream = (
+            tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
+        )

-        # Check if this is the last chunk by looking ahead
-        stream_closed = self._is_stream_closed(segment.selector)
-        # Check if stream is closed to determine if segment is complete
-        if stream_closed:
-            is_complete = True
+        stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))

-        elif value := self._variable_pool.get(segment.selector):
-            # Process scalar value
-            is_last_segment = bool(
-                self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
-            )
-            events.append(
-                self._create_stream_chunk_event(
-                    node_id=output_node_id,
-                    execution_id=execution_id,
-                    selector=segment.selector,
-                    chunk=value.markdown,
-                    is_final=is_last_segment,
-                )
-            )
-            is_complete = True
+        if stream_targets:
+            all_complete = True
+
+            for target_selector in stream_targets:
+                while self._has_unread_stream(target_selector):
+                    if event := self._pop_stream_chunk(target_selector):
+                        events.append(
+                            self._rewrite_stream_event(
+                                event=event,
+                                output_node_id=output_node_id,
+                                execution_id=execution_id,
+                                special_selector=bool(special_selector),
+                            )
+                        )
+
+                if not self._is_stream_closed(target_selector):
+                    all_complete = False
+
+            is_complete = all_complete
+
+        # Fallback: check if scalar value exists in variable pool
+        if not is_complete and not has_direct_stream:
+            if value := self._variable_pool.get(segment.selector):
+                # Process scalar value
+                is_last_segment = bool(
+                    self._active_session
+                    and self._active_session.index == len(self._active_session.template.segments) - 1
+                )
+                events.append(
+                    self._create_stream_chunk_event(
+                        node_id=output_node_id,
+                        execution_id=execution_id,
+                        selector=segment.selector,
+                        chunk=value.markdown,
+                        is_final=is_last_segment,
+                    )
+                )
+                is_complete = True

         return events, is_complete

+    def _rewrite_stream_event(
+        self,
+        event: NodeRunStreamChunkEvent,
+        output_node_id: str,
+        execution_id: str,
+        special_selector: bool,
+    ) -> NodeRunStreamChunkEvent:
+        """Rewrite event to attribute to active response node when selector is special."""
+        if not special_selector:
+            return event
+
+        return self._create_stream_chunk_event(
+            node_id=output_node_id,
+            execution_id=execution_id,
+            selector=event.selector,
+            chunk=event.chunk,
+            is_final=event.is_final,
+            chunk_type=event.chunk_type,
+            tool_call=event.tool_call,
+            tool_result=event.tool_result,
+        )
+
     def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
         """Process a text segment. Returns (events, is_complete)."""
         assert self._active_session is not None

@@ -513,6 +561,36 @@ class ResponseStreamCoordinator:

     # ============= Internal Stream Management Methods =============

+    def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
+        """Find all child stream selectors that are descendants of the parent selector.
+
+        For example, if parent_selector is ['llm', 'generation'], this will find:
+        - ['llm', 'generation', 'content']
+        - ['llm', 'generation', 'tool_calls']
+        - ['llm', 'generation', 'tool_results']
+        - ['llm', 'generation', 'thought']
+
+        Args:
+            parent_selector: The parent selector to search for children
+
+        Returns:
+            List of child selector tuples found in stream buffers or closed streams
+        """
+        parent_key = tuple(parent_selector)
+        parent_len = len(parent_key)
+        child_streams: set[tuple[str, ...]] = set()
+
+        # Search in both active buffers and closed streams
+        all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
+
+        for selector_key in all_selectors:
+            # Check if this selector is a direct child of the parent
+            # Direct child means: len(child) == len(parent) + 1 and child starts with parent
+            if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
+                child_streams.add(selector_key)
+
+        return sorted(child_streams)
+
     def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
         """
         Append a stream chunk to the internal buffer.
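The direct-child rule in `_find_child_streams` is just a length-plus-one prefix match over selector tuples. A standalone restatement with illustrative selectors:

```python
# Direct child = exactly one segment longer and shares the parent prefix.
parent = ("llm", "generation")
streams = {
    ("llm", "generation", "content"),
    ("llm", "generation", "tool_calls"),
    ("llm", "generation", "tool_calls", "0"),  # grandchild: excluded
    ("other", "text"),                          # unrelated: excluded
}

children = sorted(
    s for s in streams if len(s) == len(parent) + 1 and s[: len(parent)] == parent
)
print(children)  # [('llm', 'generation', 'content'), ('llm', 'generation', 'tool_calls')]
```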
@@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
 by ResponseStreamCoordinator to manage streaming sessions.
 """

+from __future__ import annotations
+
 from dataclasses import dataclass

 from core.workflow.nodes.answer.answer_node import AnswerNode

@@ -27,7 +29,7 @@ class ResponseSession:
     index: int = 0  # Current position in the template segments

     @classmethod
-    def from_node(cls, node: Node) -> "ResponseSession":
+    def from_node(cls, node: Node) -> ResponseSession:
         """
         Create a ResponseSession from an AnswerNode or EndNode.
@@ -42,6 +42,7 @@ class Worker(threading.Thread):
         event_queue: queue.Queue[GraphNodeEventBase],
         graph: Graph,
         layers: Sequence[GraphEngineLayer],
+        stop_event: threading.Event,
         worker_id: int = 0,
         flask_app: Flask | None = None,
         context_vars: contextvars.Context | None = None,

@@ -65,13 +66,16 @@ class Worker(threading.Thread):
         self._worker_id = worker_id
         self._flask_app = flask_app
         self._context_vars = context_vars
-        self._stop_event = threading.Event()
         self._last_task_time = time.time()
+        self._stop_event = stop_event
         self._layers = layers if layers is not None else []

     def stop(self) -> None:
-        """Signal the worker to stop processing."""
-        self._stop_event.set()
+        """Worker is controlled via shared stop_event from GraphEngine.
+
+        This method is a no-op retained for backward compatibility.
+        """
+        pass

     @property
     def is_idle(self) -> bool:
@@ -41,6 +41,7 @@ class WorkerPool:
         event_queue: queue.Queue[GraphNodeEventBase],
         graph: Graph,
         layers: list[GraphEngineLayer],
+        stop_event: threading.Event,
         flask_app: "Flask | None" = None,
         context_vars: "Context | None" = None,
         min_workers: int | None = None,

@@ -81,6 +82,7 @@ class WorkerPool:
         self._worker_counter = 0
         self._lock = threading.RLock()
         self._running = False
+        self._stop_event = stop_event

         # No longer tracking worker states with callbacks to avoid lock contention

@@ -135,7 +137,7 @@ class WorkerPool:
         # Wait for workers to finish
         for worker in self._workers:
             if worker.is_alive():
-                worker.join(timeout=10.0)
+                worker.join(timeout=2.0)

         self._workers.clear()

@@ -152,6 +154,7 @@ class WorkerPool:
             worker_id=worker_id,
             flask_app=self._flask_app,
             context_vars=self._context_vars,
+            stop_event=self._stop_event,
        )

         worker.start()
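The recurring change across GraphEngine, Dispatcher, Worker, and WorkerPool is an ownership inversion: the engine creates a single `threading.Event` and injects it into every subsystem, so one `set()` stops the dispatcher loop and all workers together. A minimal standalone sketch of the pattern (names invented for the example):

```python
import threading
import time


def worker_loop(stop_event: threading.Event, worker_id: int) -> None:
    # Each thread polls the shared event instead of owning a private one.
    while not stop_event.is_set():
        time.sleep(0.01)  # stand-in for pulling work off a queue


stop = threading.Event()
threads = [threading.Thread(target=worker_loop, args=(stop, i)) for i in range(3)]
for t in threads:
    t.start()

stop.set()  # one signal stops every thread
for t in threads:
    t.join(timeout=2.0)  # mirrors the shortened join timeout in this diff
```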
@@ -36,6 +36,7 @@ from .loop import (

 # Node events
 from .node import (
+    ChunkType,
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunPauseRequestedEvent,

@@ -44,10 +45,13 @@ from .node import (
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
+    ToolCall,
+    ToolResult,
 )

 __all__ = [
     "BaseGraphEvent",
+    "ChunkType",
     "GraphEngineEvent",
     "GraphNodeEventBase",
     "GraphRunAbortedEvent",

@@ -73,4 +77,6 @@ __all__ = [
     "NodeRunStartedEvent",
     "NodeRunStreamChunkEvent",
     "NodeRunSucceededEvent",
+    "ToolCall",
+    "ToolResult",
 ]
@@ -1,10 +1,11 @@
 from collections.abc import Sequence
 from datetime import datetime
+from enum import StrEnum

 from pydantic import Field

 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
 from core.workflow.entities.pause_reason import PauseReason

 from .base import GraphNodeEventBase

@@ -21,13 +22,39 @@ class NodeRunStartedEvent(GraphNodeEventBase):
     provider_id: str = ""


+class ChunkType(StrEnum):
+    """Stream chunk type for LLM-related events."""
+
+    TEXT = "text"  # Normal text streaming
+    TOOL_CALL = "tool_call"  # Tool call arguments streaming
+    TOOL_RESULT = "tool_result"  # Tool execution result
+    THOUGHT = "thought"  # Agent thinking process (ReAct)
+    THOUGHT_START = "thought_start"  # Agent thought start
+    THOUGHT_END = "thought_end"  # Agent thought end
+
+
 class NodeRunStreamChunkEvent(GraphNodeEventBase):
-    # Spec-compliant fields
+    """Stream chunk event for workflow node execution."""
+
+    # Base fields
     selector: Sequence[str] = Field(
         ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
     )
     chunk: str = Field(..., description="the actual chunk content")
     is_final: bool = Field(default=False, description="indicates if this is the last chunk")
+    chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
+
+    # Tool call fields (when chunk_type == TOOL_CALL)
+    tool_call: ToolCall | None = Field(
+        default=None,
+        description="structured payload for tool_call chunks",
+    )
+
+    # Tool result fields (when chunk_type == TOOL_RESULT)
+    tool_result: ToolResult | None = Field(
+        default=None,
+        description="structured payload for tool_result chunks",
+    )


 class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
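The convention introduced here is that a `TOOL_CALL` chunk carries its structured payload alongside the raw chunk text. A small sketch exercising just the new types (import paths follow this diff; the base-class fields of `GraphNodeEventBase` are not shown here, so the full event is not constructed):

```python
from core.workflow.graph_events import ChunkType, ToolCall

chunk_type = ChunkType.TOOL_CALL
tool_call = ToolCall(id="call_1", name="web_search", arguments='{"query": "dify"}')

# StrEnum members compare equal to their string value, which is what lets
# downstream consumers switch on the serialized chunk_type directly.
assert chunk_type == "tool_call"
```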