diff --git a/.claude/settings.json b/.claude/settings.json index c5c514b5f5..509dbe8447 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -4,6 +4,6 @@ "context7@claude-plugins-official": true, "typescript-lsp@claude-plugins-official": true, "pyright-lsp@claude-plugins-official": true, - "ralph-wiggum@claude-plugins-official": true + "ralph-loop@claude-plugins-official": true } } diff --git a/.claude/skills/skill-creator/SKILL.md b/.claude/skills/skill-creator/SKILL.md new file mode 100644 index 0000000000..b49da5ac68 --- /dev/null +++ b/.claude/skills/skill-creator/SKILL.md @@ -0,0 +1,355 @@ +--- +name: skill-creator +description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations. +--- + +# Skill Creator + +This skill provides guidance for creating effective skills. + +## About Skills + +Skills are modular, self-contained packages that extend Claude's capabilities by providing +specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific +domains or tasks—they transform Claude from a general-purpose agent into a specialized agent +equipped with procedural knowledge that no model can fully possess. + +### What Skills Provide + +1. Specialized workflows - Multi-step procedures for specific domains +2. Tool integrations - Instructions for working with specific file formats or APIs +3. Domain expertise - Company-specific knowledge, schemas, business logic +4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks + +## Core Principles + +### Concise is Key + +The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request. + +**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?" + +Prefer concise examples over verbose explanations. + +### Set Appropriate Degrees of Freedom + +Match the level of specificity to the task's fragility and variability: + +**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach. + +**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior. + +**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed. + +Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom). + +### Anatomy of a Skill + +Every skill consists of a required SKILL.md file and optional bundled resources: + +``` +skill-name/ +├── SKILL.md (required) +│ ├── YAML frontmatter metadata (required) +│ │ ├── name: (required) +│ │ └── description: (required) +│ └── Markdown instructions (required) +└── Bundled Resources (optional) + ├── scripts/ - Executable code (Python/Bash/etc.) + ├── references/ - Documentation intended to be loaded into context as needed + └── assets/ - Files used in output (templates, icons, fonts, etc.) 
```

#### SKILL.md (required)

Every SKILL.md consists of:

- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields Claude reads to determine when the skill gets used, so it is very important to be clear and comprehensive in describing what the skill is and when it should be used.
- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all).

#### Bundled Resources (optional)

##### Scripts (`scripts/`)

Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten.

- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed
- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks
- **Benefits**: Token efficient, deterministic, may be executed without loading into context
- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments

##### References (`references/`)

Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking.

- **When to include**: For documentation that Claude should reference while working
- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications
- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides
- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed
- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md
- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files.

##### Assets (`assets/`)

Files not intended to be loaded into context, but rather used within the output Claude produces.

- **When to include**: When the skill needs files that will be used in the final output
- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography
- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified
- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context

#### What Not to Include in a Skill

A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including:

- README.md
- INSTALLATION_GUIDE.md
- QUICK_REFERENCE.md
- CHANGELOG.md
- etc.

The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxiliary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion.
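To make the frontmatter's role concrete, here is a minimal sketch of how a runtime might scan installed skills and surface only their metadata, leaving each body unloaded until its skill triggers. The sketch is illustrative only: it assumes PyYAML and a local `skills/` directory, and `read_skill_metadata` is a hypothetical helper, not part of any shipped tooling.

```python
# Minimal sketch: read ONLY the frontmatter metadata (name + description),
# never the Markdown body. Assumes PyYAML and skills under ./skills/.
from pathlib import Path

import yaml


def read_skill_metadata(skill_dir: Path) -> dict | None:
    """Return the name/description frontmatter of a skill, or None."""
    skill_md = skill_dir / "SKILL.md"
    if not skill_md.exists():
        return None
    text = skill_md.read_text(encoding="utf-8")
    parts = text.split("---", 2)
    if not text.startswith("---") or len(parts) < 3:
        return None
    # parts[1] holds everything between the first two "---" markers.
    meta = yaml.safe_load(parts[1]) or {}
    return {"name": meta.get("name"), "description": meta.get("description")}


if __name__ == "__main__":
    for skill_dir in sorted(Path("skills").iterdir()):
        meta = read_skill_metadata(skill_dir)
        if meta is not None:
            print(f"{meta['name']}: {meta['description']}")
```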
### Progressive Disclosure Design Principle

Skills use a three-level loading system to manage context efficiently:

1. **Metadata (name + description)** - Always in context (~100 words)
2. **SKILL.md body** - When skill triggers (<5k words)
3. **Bundled resources** - As needed by Claude (unlimited, because scripts can be executed without reading them into the context window)

#### Progressive Disclosure Patterns

Keep the SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting content out into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, so the reader of the skill knows they exist and when to use them.

**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files.

**Pattern 1: High-level guide with references**

```markdown
# PDF Processing

## Quick start

Extract text with pdfplumber:
[code example]

## Advanced features

- **Form filling**: See [FORMS.md](FORMS.md) for complete guide
- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods
- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns
```

Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed.

**Pattern 2: Domain-specific organization**

For skills with multiple domains, organize content by domain to avoid loading irrelevant context:

```
bigquery-skill/
├── SKILL.md (overview and navigation)
└── reference/
    ├── finance.md (revenue, billing metrics)
    ├── sales.md (opportunities, pipeline)
    ├── product.md (API usage, features)
    └── marketing.md (campaigns, attribution)
```

When a user asks about sales metrics, Claude only reads sales.md.

Similarly, for skills supporting multiple frameworks or variants, organize by variant:

```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
    ├── aws.md (AWS deployment patterns)
    ├── gcp.md (GCP deployment patterns)
    └── azure.md (Azure deployment patterns)
```

When the user chooses AWS, Claude only reads aws.md.

**Pattern 3: Conditional details**

Show basic content, link to advanced content:

```markdown
# DOCX Processing

## Creating documents

Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md).

## Editing documents

For simple edits, modify the XML directly.

**For tracked changes**: See [REDLINING.md](REDLINING.md)
**For OOXML details**: See [OOXML.md](OOXML.md)
```

Claude reads REDLINING.md or OOXML.md only when the user needs those features.

**Important guidelines:**

- **Avoid deeply nested references** - Keep references one level deep from SKILL.md, and link every reference file directly from SKILL.md.
- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing.

## Skill Creation Process

Skill creation involves these steps:

1. Understand the skill with concrete examples
2. Plan reusable skill contents (scripts, references, assets)
3. Initialize the skill (run init_skill.py)
4. Edit the skill (implement resources and write SKILL.md)
5. Package the skill (run package_skill.py)
6. Iterate based on real usage

Follow these steps in order, skipping only if there is a clear reason why they are not applicable.

### Step 1: Understanding the Skill with Concrete Examples

Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.

To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.

For example, when building an image-editor skill, relevant questions include:

- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"

To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed.

Conclude this step when there is a clear sense of the functionality the skill should support.

### Step 2: Planning the Reusable Skill Contents

To turn concrete examples into an effective skill, analyze each example by:

1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly

Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:

1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill

Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:

1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill

Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:

1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill

To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.

### Step 3: Initializing the Skill

At this point, it is time to actually create the skill.

Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step.

When creating a new skill from scratch, always run the `init_skill.py` script. The script generates a new template skill directory that includes everything a skill requires, making skill creation more efficient and reliable.

Usage:

```bash
scripts/init_skill.py <skill-name> --path <path>
```

The script:

- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Creates example resource directories: `scripts/`, `references/`, and `assets/`
- Adds example files in each directory that can be customized or deleted

After initialization, customize or remove the generated SKILL.md and example files as needed.

### Step 4: Edit the Skill

When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively.

#### Learn Proven Design Patterns

Consult these guides based on the skill's needs:

- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns

These files contain established best practices for effective skill design.

#### Start with Reusable Skill Contents

To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.

Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion.

Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them.

#### Update SKILL.md

**Writing Guidelines:** Always use imperative/infinitive form.

##### Frontmatter

Write the YAML frontmatter with `name` and `description`:

- `name`: The skill name
- `description`: This is the primary triggering mechanism for the skill, and helps Claude understand when to use it.
  - Include both what the skill does and specific triggers/contexts for when to use it.
  - Include all "when to use" information here, not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude.
  - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"

Do not include any other fields in the YAML frontmatter.

##### Body

Write instructions for using the skill and its bundled resources.

### Step 5: Packaging a Skill

Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements:

```bash
scripts/package_skill.py <path-to-skill>
```

Optional output directory specification:

```bash
scripts/package_skill.py <path-to-skill> ./dist
```

The packaging script will:

1. **Validate** the skill automatically, checking:

   - YAML frontmatter format and required fields
   - Skill naming conventions and directory structure
   - Description completeness and quality
   - File organization and resource references

2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.

If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.

### Step 6: Iterate

After testing the skill, users may request improvements. Often this happens right after using the skill, while the context of how the skill performed is still fresh.

**Iteration workflow:**

1. Use the skill on real tasks
2. Notice struggles or inefficiencies
3. Identify how SKILL.md or bundled resources should be updated
4. Implement changes and test again
diff --git a/.claude/skills/skill-creator/references/output-patterns.md b/.claude/skills/skill-creator/references/output-patterns.md
new file mode 100644
index 0000000000..022e85fe5e
--- /dev/null
+++ b/.claude/skills/skill-creator/references/output-patterns.md
@@ -0,0 +1,86 @@
# Output Patterns

Use these patterns when skills need to produce consistent, high-quality output.

## Template Pattern

Provide templates for output format. Match the level of strictness to your needs.

**For strict requirements (like API responses or data formats):**

```markdown
## Report structure

ALWAYS use this exact template structure:

# [Analysis Title]

## Executive summary
[One-paragraph overview of key findings]

## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data

## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```

**For flexible guidance (when adaptation is useful):**

```markdown
## Report structure

Here is a sensible default format, but use your best judgment:

# [Analysis Title]

## Executive summary
[Overview]

## Key findings
[Adapt sections based on what you discover]

## Recommendations
[Tailor to the specific context]

Adjust sections as needed for the specific analysis type.
```

## Examples Pattern

For skills where output quality depends on seeing examples, provide input/output pairs:

```markdown
## Commit message format

Generate commit messages following these examples:

**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```

feat(auth): implement JWT-based authentication

Add login endpoint and token validation middleware

```

**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```

fix(reports): correct date formatting in timezone conversion

Use UTC timestamps consistently across report generation

```

Follow this style: type(scope): brief description, then detailed explanation.
```

Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.
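When the target format is this regular, the example pairs can be backed by a small checker script so outputs stay consistent. A minimal sketch, assuming a conventional-commit-style first line; the accepted types and the regex are illustrative assumptions, not part of this skill:

```python
import re

# Illustrative pattern for a "type(scope): brief description" first line.
# The allowed types are an assumption; adjust them to the house style.
FIRST_LINE = re.compile(r"^(feat|fix|docs|refactor|test|chore)\([a-z0-9-]+\): .+$")


def first_line_ok(message: str) -> bool:
    """Check the first line of a commit message against the format."""
    first_line = message.splitlines()[0] if message.strip() else ""
    return FIRST_LINE.match(first_line) is not None


assert first_line_ok("feat(auth): implement JWT-based authentication")
assert not first_line_ok("added some stuff")
```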
diff --git a/.claude/skills/skill-creator/references/workflows.md b/.claude/skills/skill-creator/references/workflows.md new file mode 100644 index 0000000000..54b0174078 --- /dev/null +++ b/.claude/skills/skill-creator/references/workflows.md @@ -0,0 +1,28 @@ +# Workflow Patterns + +## Sequential Workflows + +For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md: + +```markdown +Filling a PDF form involves these steps: + +1. Analyze the form (run analyze_form.py) +2. Create field mapping (edit fields.json) +3. Validate mapping (run validate_fields.py) +4. Fill the form (run fill_form.py) +5. Verify output (run verify_output.py) +``` + +## Conditional Workflows + +For tasks with branching logic, guide Claude through decision points: + +```markdown +1. Determine the modification type: + **Creating new content?** → Follow "Creation workflow" below + **Editing existing content?** → Follow "Editing workflow" below + +2. Creation workflow: [steps] +3. Editing workflow: [steps] +``` diff --git a/.claude/skills/skill-creator/scripts/init_skill.py b/.claude/skills/skill-creator/scripts/init_skill.py new file mode 100755 index 0000000000..249fffcbbd --- /dev/null +++ b/.claude/skills/skill-creator/scripts/init_skill.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +""" +Skill Initializer - Creates a new skill from template + +Usage: + init_skill.py --path + +Examples: + init_skill.py my-new-skill --path skills/public + init_skill.py my-api-helper --path skills/private + init_skill.py custom-skill --path /custom/location +""" + +import sys +from pathlib import Path + + +SKILL_TEMPLATE = """--- +name: {skill_name} +description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.] +--- + +# {skill_title} + +## Overview + +[TODO: 1-2 sentences explaining what this skill enables] + +## Structuring This Skill + +[TODO: Choose the structure that best fits this skill's purpose. Common patterns: + +**1. Workflow-Based** (best for sequential processes) +- Works well when there are clear step-by-step procedures +- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing" +- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2... + +**2. Task-Based** (best for tool collections) +- Works well when the skill offers different operations/capabilities +- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text" +- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2... + +**3. Reference/Guidelines** (best for standards or specifications) +- Works well for brand guidelines, coding standards, or requirements +- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features" +- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage... + +**4. Capabilities-Based** (best for integrated systems) +- Works well when the skill provides multiple interrelated features +- Example: Product Management with "Core Capabilities" → numbered capability list +- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature... + +Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations). 
+ +Delete this entire "Structuring This Skill" section when done - it's just guidance.] + +## [TODO: Replace with the first main section based on chosen structure] + +[TODO: Add content here. See examples in existing skills: +- Code samples for technical skills +- Decision trees for complex workflows +- Concrete examples with realistic user requests +- References to scripts/templates/references as needed] + +## Resources + +This skill includes example resource directories that demonstrate how to organize different types of bundled resources: + +### scripts/ +Executable code (Python/Bash/etc.) that can be run directly to perform specific operations. + +**Examples from other skills:** +- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation +- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing + +**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations. + +**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments. + +### references/ +Documentation and reference material intended to be loaded into context to inform Claude's process and thinking. + +**Examples from other skills:** +- Product management: `communication.md`, `context_building.md` - detailed workflow guides +- BigQuery: API reference documentation and query examples +- Finance: Schema documentation, company policies + +**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working. + +### assets/ +Files not intended to be loaded into context, but rather used within the output Claude produces. + +**Examples from other skills:** +- Brand styling: PowerPoint template files (.pptx), logo files +- Frontend builder: HTML/React boilerplate project directories +- Typography: Font files (.ttf, .woff2) + +**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output. + +--- + +**Any unneeded directories can be deleted.** Not every skill requires all three types of resources. +""" + +EXAMPLE_SCRIPT = '''#!/usr/bin/env python3 +""" +Example helper script for {skill_name} + +This is a placeholder script that can be executed directly. +Replace with actual implementation or delete if not needed. + +Example real scripts from other skills: +- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields +- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images +""" + +def main(): + print("This is an example script for {skill_name}") + # TODO: Add actual script logic here + # This could be data processing, file conversion, API calls, etc. + +if __name__ == "__main__": + main() +''' + +EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title} + +This is a placeholder for detailed reference documentation. +Replace with actual reference content or delete if not needed. 
+ +Example real reference docs from other skills: +- product-management/references/communication.md - Comprehensive guide for status updates +- product-management/references/context_building.md - Deep-dive on gathering context +- bigquery/references/ - API references and query examples + +## When Reference Docs Are Useful + +Reference docs are ideal for: +- Comprehensive API documentation +- Detailed workflow guides +- Complex multi-step processes +- Information too lengthy for main SKILL.md +- Content that's only needed for specific use cases + +## Structure Suggestions + +### API Reference Example +- Overview +- Authentication +- Endpoints with examples +- Error codes +- Rate limits + +### Workflow Guide Example +- Prerequisites +- Step-by-step instructions +- Common patterns +- Troubleshooting +- Best practices +""" + +EXAMPLE_ASSET = """# Example Asset File + +This placeholder represents where asset files would be stored. +Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed. + +Asset files are NOT intended to be loaded into context, but rather used within +the output Claude produces. + +Example asset files from other skills: +- Brand guidelines: logo.png, slides_template.pptx +- Frontend builder: hello-world/ directory with HTML/React boilerplate +- Typography: custom-font.ttf, font-family.woff2 +- Data: sample_data.csv, test_dataset.json + +## Common Asset Types + +- Templates: .pptx, .docx, boilerplate directories +- Images: .png, .jpg, .svg, .gif +- Fonts: .ttf, .otf, .woff, .woff2 +- Boilerplate code: Project directories, starter files +- Icons: .ico, .svg +- Data files: .csv, .json, .xml, .yaml + +Note: This is a text placeholder. Actual assets can be any file type. +""" + + +def title_case_skill_name(skill_name): + """Convert hyphenated skill name to Title Case for display.""" + return " ".join(word.capitalize() for word in skill_name.split("-")) + + +def init_skill(skill_name, path): + """ + Initialize a new skill directory with template SKILL.md. 
+ + Args: + skill_name: Name of the skill + path: Path where the skill directory should be created + + Returns: + Path to created skill directory, or None if error + """ + # Determine skill directory path + skill_dir = Path(path).resolve() / skill_name + + # Check if directory already exists + if skill_dir.exists(): + print(f"❌ Error: Skill directory already exists: {skill_dir}") + return None + + # Create skill directory + try: + skill_dir.mkdir(parents=True, exist_ok=False) + print(f"✅ Created skill directory: {skill_dir}") + except Exception as e: + print(f"❌ Error creating directory: {e}") + return None + + # Create SKILL.md from template + skill_title = title_case_skill_name(skill_name) + skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title) + + skill_md_path = skill_dir / "SKILL.md" + try: + skill_md_path.write_text(skill_content) + print("✅ Created SKILL.md") + except Exception as e: + print(f"❌ Error creating SKILL.md: {e}") + return None + + # Create resource directories with example files + try: + # Create scripts/ directory with example script + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(exist_ok=True) + example_script = scripts_dir / "example.py" + example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name)) + example_script.chmod(0o755) + print("✅ Created scripts/example.py") + + # Create references/ directory with example reference doc + references_dir = skill_dir / "references" + references_dir.mkdir(exist_ok=True) + example_reference = references_dir / "api_reference.md" + example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title)) + print("✅ Created references/api_reference.md") + + # Create assets/ directory with example asset placeholder + assets_dir = skill_dir / "assets" + assets_dir.mkdir(exist_ok=True) + example_asset = assets_dir / "example_asset.txt" + example_asset.write_text(EXAMPLE_ASSET) + print("✅ Created assets/example_asset.txt") + except Exception as e: + print(f"❌ Error creating resource directories: {e}") + return None + + # Print next steps + print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}") + print("\nNext steps:") + print("1. Edit SKILL.md to complete the TODO items and update the description") + print("2. Customize or delete the example files in scripts/, references/, and assets/") + print("3. 
Run the validator when ready to check the skill structure") + + return skill_dir + + +def main(): + if len(sys.argv) < 4 or sys.argv[2] != "--path": + print("Usage: init_skill.py --path ") + print("\nSkill name requirements:") + print(" - Hyphen-case identifier (e.g., 'data-analyzer')") + print(" - Lowercase letters, digits, and hyphens only") + print(" - Max 40 characters") + print(" - Must match directory name exactly") + print("\nExamples:") + print(" init_skill.py my-new-skill --path skills/public") + print(" init_skill.py my-api-helper --path skills/private") + print(" init_skill.py custom-skill --path /custom/location") + sys.exit(1) + + skill_name = sys.argv[1] + path = sys.argv[3] + + print(f"🚀 Initializing skill: {skill_name}") + print(f" Location: {path}") + print() + + result = init_skill(skill_name, path) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/package_skill.py b/.claude/skills/skill-creator/scripts/package_skill.py new file mode 100755 index 0000000000..736b928be0 --- /dev/null +++ b/.claude/skills/skill-creator/scripts/package_skill.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Skill Packager - Creates a distributable .skill file of a skill folder + +Usage: + python utils/package_skill.py [output-directory] + +Example: + python utils/package_skill.py skills/public/my-skill + python utils/package_skill.py skills/public/my-skill ./dist +""" + +import sys +import zipfile +from pathlib import Path +from quick_validate import validate_skill + + +def package_skill(skill_path, output_dir=None): + """ + Package a skill folder into a .skill file. + + Args: + skill_path: Path to the skill folder + output_dir: Optional output directory for the .skill file (defaults to current directory) + + Returns: + Path to the created .skill file, or None if error + """ + skill_path = Path(skill_path).resolve() + + # Validate skill folder exists + if not skill_path.exists(): + print(f"❌ Error: Skill folder not found: {skill_path}") + return None + + if not skill_path.is_dir(): + print(f"❌ Error: Path is not a directory: {skill_path}") + return None + + # Validate SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + print(f"❌ Error: SKILL.md not found in {skill_path}") + return None + + # Run validation before packaging + print("🔍 Validating skill...") + valid, message = validate_skill(skill_path) + if not valid: + print(f"❌ Validation failed: {message}") + print(" Please fix the validation errors before packaging.") + return None + print(f"✅ {message}\n") + + # Determine output location + skill_name = skill_path.name + if output_dir: + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path.cwd() + + skill_filename = output_path / f"{skill_name}.skill" + + # Create the .skill file (zip format) + try: + with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf: + # Walk through the skill directory + for file_path in skill_path.rglob("*"): + if file_path.is_file(): + # Calculate the relative path within the zip + arcname = file_path.relative_to(skill_path.parent) + zipf.write(file_path, arcname) + print(f" Added: {arcname}") + + print(f"\n✅ Successfully packaged skill to: {skill_filename}") + return skill_filename + + except Exception as e: + print(f"❌ Error creating .skill file: {e}") + return None + + +def main(): + if len(sys.argv) < 2: + print("Usage: python utils/package_skill.py 
[output-directory]") + print("\nExample:") + print(" python utils/package_skill.py skills/public/my-skill") + print(" python utils/package_skill.py skills/public/my-skill ./dist") + sys.exit(1) + + skill_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else None + + print(f"📦 Packaging skill: {skill_path}") + if output_dir: + print(f" Output directory: {output_dir}") + print() + + result = package_skill(skill_path, output_dir) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/quick_validate.py b/.claude/skills/skill-creator/scripts/quick_validate.py new file mode 100755 index 0000000000..66eb0a71bf --- /dev/null +++ b/.claude/skills/skill-creator/scripts/quick_validate.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Quick validation script for skills - minimal version +""" + +import sys +import os +import re +import yaml +from pathlib import Path + + +def validate_skill(skill_path): + """Basic validation of a skill""" + skill_path = Path(skill_path) + + # Check SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + return False, "SKILL.md not found" + + # Read and validate frontmatter + content = skill_md.read_text() + if not content.startswith("---"): + return False, "No YAML frontmatter found" + + # Extract frontmatter + match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) + if not match: + return False, "Invalid frontmatter format" + + frontmatter_text = match.group(1) + + # Parse YAML frontmatter + try: + frontmatter = yaml.safe_load(frontmatter_text) + if not isinstance(frontmatter, dict): + return False, "Frontmatter must be a YAML dictionary" + except yaml.YAMLError as e: + return False, f"Invalid YAML in frontmatter: {e}" + + # Define allowed properties + ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"} + + # Check for unexpected properties (excluding nested keys under metadata) + unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES + if unexpected_keys: + return False, ( + f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. " + f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}" + ) + + # Check required fields + if "name" not in frontmatter: + return False, "Missing 'name' in frontmatter" + if "description" not in frontmatter: + return False, "Missing 'description' in frontmatter" + + # Extract name for validation + name = frontmatter.get("name", "") + if not isinstance(name, str): + return False, f"Name must be a string, got {type(name).__name__}" + name = name.strip() + if name: + # Check naming convention (hyphen-case: lowercase with hyphens) + if not re.match(r"^[a-z0-9-]+$", name): + return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)" + if name.startswith("-") or name.endswith("-") or "--" in name: + return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens" + # Check name length (max 64 characters per spec) + if len(name) > 64: + return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters." 
+ + # Extract and validate description + description = frontmatter.get("description", "") + if not isinstance(description, str): + return False, f"Description must be a string, got {type(description).__name__}" + description = description.strip() + if description: + # Check for angle brackets + if "<" in description or ">" in description: + return False, "Description cannot contain angle brackets (< or >)" + # Check description length (max 1024 characters per spec) + if len(description) > 1024: + return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters." + + return True, "Skill is valid!" + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python quick_validate.py ") + sys.exit(1) + + valid, message = validate_skill(sys.argv[1]) + print(message) + sys.exit(0 if valid else 1) diff --git a/api/.importlinter b/api/.importlinter index acb21ae522..2dec958788 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -53,7 +53,6 @@ ignore_imports = core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database core.workflow.nodes.tool.tool_node -> extensions.ext_database - core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.graph_engine.manager -> extensions.ext_redis core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis diff --git a/api/Dockerfile b/api/Dockerfile index e800e60322..a08d4e3aab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -50,16 +50,33 @@ WORKDIR /app/api # Create non-root user ARG dify_uid=1001 +ARG NODE_MAJOR=22 +ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1 +ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4 RUN groupadd -r -g ${dify_uid} dify && \ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \ chown -R dify:dify /app RUN \ apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && mkdir -p /etc/apt/keyrings \ + && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \ + && gpg --show-keys --with-colons /tmp/nodesource.gpg \ + | awk -F: '/^fpr:/ {print $10}' \ + | grep -Fx "${NODESOURCE_KEY_FPR}" \ + && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \ + && rm -f /tmp/nodesource.gpg \ + && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \ + > /etc/apt/sources.list.d/nodesource.list \ + && apt-get update \ # Install dependencies && apt-get install -y --no-install-recommends \ # basic environment - curl nodejs \ + nodejs=${NODE_PACKAGE_VERSION} \ # for gmpy2 \ libgmp-dev libmpfr-dev libmpc-dev \ # For Security diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 44cf89d6a9..d66bb7063f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,14 +1,16 @@ import re import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal, TypeAlias from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from 
werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -19,27 +21,19 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager from core.workflow.enums import NodeType from extensions.ext_database import db -from fields.app_fields import ( - deleted_tool_fields, - model_config_fields, - model_config_partial_fields, - site_fields, - tag_fields, -) -from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict -from libs.helper import AppIconUrlField, TimestampField from libs.login import current_account_with_tenant, login_required from models import App, Workflow +from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppListQuery(BaseModel): @@ -192,124 +186,292 @@ class AppTracePayload(BaseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +JSONValue: TypeAlias = Any -reg(AppListQuery) -reg(CreateAppPayload) -reg(UpdateAppPayload) -reg(CopyAppPayload) -reg(AppExportQuery) -reg(AppNamePayload) -reg(AppIconPayload) -reg(AppSiteStatusPayload) -reg(AppApiStatusPayload) -reg(AppTracePayload) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base models first -tag_model = console_ns.model("Tag", tag_fields) -workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -model_config_model = console_ns.model("ModelConfig", model_config_fields) -model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields) +def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE.value: + return None + return file_helpers.get_signed_file_url(icon) -deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields) -site_model = console_ns.model("Site", site_fields) +class Tag(ResponseModel): + id: str + name: str + type: str -app_partial_model = console_ns.model( - "AppPartial", - { - "id": fields.String, - "name": fields.String, - "max_active_requests": fields.Raw(), - "description": fields.String(attribute="desc_or_prompt"), - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "model_config": 
fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "tags": fields.List(fields.Nested(tag_model)), - "access_mode": fields.String, - "create_user_name": fields.String, - "author_name": fields.String, - "has_draft_trigger": fields.Boolean, - }, -) -app_detail_model = console_ns.model( - "AppDetail", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon": fields.String, - "icon_background": fields.String, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "tracing": fields.Raw, - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - }, -) +class WorkflowPartial(ResponseModel): + id: str + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None -app_detail_with_site_model = console_ns.model( - "AppDetailWithSite", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "api_base_url": fields.String, - "use_icon_as_answer_icon": fields.Boolean, - "max_active_requests": fields.Integer, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "deleted_tools": fields.List(fields.Nested(deleted_tool_model)), - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - "site": fields.Nested(site_model), - }, -) + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -app_pagination_model = console_ns.model( - "AppPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(app_partial_model), attribute="items"), - }, + +class ModelConfigPartial(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + pre_prompt: str | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ModelConfig(ResponseModel): + opening_statement: str | None = 
None + suggested_questions: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions") + ) + suggested_questions_after_answer: JSONValue | None = Field( + default=None, + validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"), + ) + speech_to_text: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text") + ) + text_to_speech: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech") + ) + retriever_resource: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource") + ) + annotation_reply: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply") + ) + more_like_this: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this") + ) + sensitive_word_avoidance: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance") + ) + external_data_tools: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools") + ) + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + user_input_form: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form") + ) + dataset_query_variable: str | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode")) + prompt_type: str | None = None + chat_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config") + ) + completion_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config") + ) + dataset_configs: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs") + ) + file_upload: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload") + ) + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class Site(ResponseModel): + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str | None = None + icon_type: str | IconType | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str | None = None + chat_color_theme: str | None = None + chat_color_theme_inverted: bool | None = None + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str | None = None + prompt_public: bool | None = None + app_base_url: str | None = None + show_workflow_steps: bool | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + 
updated_by: str | None = None + updated_at: int | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("icon_type", mode="before") + @classmethod + def _normalize_icon_type(cls, value: str | IconType | None) -> str | None: + if isinstance(value, IconType): + return value.value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DeletedTool(ResponseModel): + type: str + tool_name: str + provider_id: str + + +class AppPartial(ResponseModel): + id: str + name: str + max_active_requests: int | None = None + description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description")) + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + model_config_: ModelConfigPartial | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + tags: list[Tag] = Field(default_factory=list) + access_mode: str | None = None + create_user_name: str | None = None + author_name: str | None = None + has_draft_trigger: bool | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetail(ResponseModel): + id: str + name: str + description: str | None = None + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon: str | None = None + icon_background: str | None = None + enable_site: bool + enable_api: bool + model_config_: ModelConfig | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + tracing: JSONValue | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + access_mode: str | None = None + tags: list[Tag] = Field(default_factory=list) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetailWithSite(AppDetail): + icon_type: str | None = None + api_base_url: str | None = None + max_active_requests: int | None = None + deleted_tools: list[DeletedTool] = Field(default_factory=list) + site: Site | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class AppPagination(ResponseModel): + page: int + limit: int = Field(validation_alias=AliasChoices("per_page", "limit")) + total: int + has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more")) + 
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data")) + + +class AppExportResponse(ResponseModel): + data: str + + +register_schema_models( + console_ns, + AppListQuery, + CreateAppPayload, + UpdateAppPayload, + CopyAppPayload, + AppExportQuery, + AppNamePayload, + AppIconPayload, + AppSiteStatusPayload, + AppApiStatusPayload, + AppTracePayload, + Tag, + WorkflowPartial, + ModelConfigPartial, + ModelConfig, + Site, + DeletedTool, + AppPartial, + AppDetail, + AppDetailWithSite, + AppPagination, + AppExportResponse, ) @@ -318,7 +480,7 @@ class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.expect(console_ns.models[AppListQuery.__name__]) - @console_ns.response(200, "Success", app_pagination_model) + @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__]) @setup_required @login_required @account_initialization_required @@ -334,7 +496,8 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: - return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} + empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) + return empty.model_dump(mode="json"), 200 if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] @@ -378,18 +541,18 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids - return marshal(app_pagination, app_pagination_model), 200 + pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True) + return pagination_model.model_dump(mode="json"), 200 @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) - @console_ns.response(201, "App created successfully", app_detail_model) + @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -399,8 +562,8 @@ class AppListApi(Resource): app_service = AppService() app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) - - return app, 201 + app_detail = AppDetail.model_validate(app, from_attributes=True) + return app_detail.model_dump(mode="json"), 201 @console_ns.route("/apps/") @@ -408,13 +571,12 @@ class AppApi(Resource): @console_ns.doc("get_app_detail") @console_ns.doc(description="Get application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Success", app_detail_with_site_model) + @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required - @get_app_model - @marshal_with(app_detail_with_site_model) + @get_app_model(mode=None) def get(self, app_model): """Get app detail""" app_service = AppService() @@ -425,21 +587,21 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = 
app_setting.access_mode - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("update_app") @console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) - @console_ns.response(200, "App updated successfully", app_detail_with_site_model) + @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -456,8 +618,8 @@ class AppApi(Resource): "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) - - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("delete_app") @console_ns.doc(description="Delete application") @@ -483,14 +645,13 @@ class AppCopyApi(Resource): @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) - @console_ns.response(201, "App copied successfully", app_detail_with_site_model) + @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -516,7 +677,8 @@ class AppCopyApi(Resource): stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 + response_model = AppDetailWithSite.model_validate(app, from_attributes=True) + return response_model.model_dump(mode="json"), 201 @console_ns.route("/apps//export") @@ -525,11 +687,7 @@ class AppExportApi(Resource): @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": "Application ID to export"}) @console_ns.expect(console_ns.models[AppExportQuery.__name__]) - @console_ns.response( - 200, - "App exported successfully", - console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), - ) + @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @get_app_model @setup_required @@ -540,13 +698,14 @@ class AppExportApi(Resource): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return { - "data": AppDslService.export_dsl( + payload = AppExportResponse( + data=AppDslService.export_dsl( app_model=app_model, include_secret=args.include_secret, workflow_id=args.workflow_id, ) - } + ) + return payload.model_dump(mode="json") @console_ns.route("/apps//name") @@ -555,20 +714,19 @@ class AppNameApi(Resource): 
@console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppNamePayload.__name__]) - @console_ns.response(200, "Name availability checked") + @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__]) @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_name(app_model, args.name) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//icon") @@ -582,16 +740,15 @@ class AppIconApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//site-enable") @@ -600,21 +757,20 @@ class AppSiteStatus(Resource): @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) - @console_ns.response(200, "Site status updated successfully", app_detail_model) + @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.enable_site) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//api-enable") @@ -623,21 +779,20 @@ class AppApiStatus(Resource): @console_ns.doc(description="Enable or disable app API") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) - @console_ns.response(200, "API status updated successfully", app_detail_model) + @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) def post(self, app_model): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.enable_api) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//trace") diff --git 
a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index ef2f86d4be..56816dd462 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -348,10 +348,13 @@ class CompletionConversationApi(Resource):
         )
 
         if args.keyword:
+            from libs.helper import escape_like_pattern
+
+            escaped_keyword = escape_like_pattern(args.keyword)
             query = query.join(Message, Message.conversation_id == Conversation.id).where(
                 or_(
-                    Message.query.ilike(f"%{args.keyword}%"),
-                    Message.answer.ilike(f"%{args.keyword}%"),
+                    Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
+                    Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
                 )
             )
 
@@ -460,7 +463,10 @@ class ChatConversationApi(Resource):
         query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
 
         if args.keyword:
-            keyword_filter = f"%{args.keyword}%"
+            from libs.helper import escape_like_pattern
+
+            escaped_keyword = escape_like_pattern(args.keyword)
+            keyword_filter = f"%{escaped_keyword}%"
             query = (
                 query.join(
                     Message,
@@ -469,11 +475,11 @@
                 .join(subquery, subquery.c.conversation_id == Conversation.id)
                 .where(
                     or_(
-                        Message.query.ilike(keyword_filter),
-                        Message.answer.ilike(keyword_filter),
-                        Conversation.name.ilike(keyword_filter),
-                        Conversation.introduction.ilike(keyword_filter),
-                        subquery.c.from_end_user_session_id.ilike(keyword_filter),
+                        Message.query.ilike(keyword_filter, escape="\\"),
+                        Message.answer.ilike(keyword_filter, escape="\\"),
+                        Conversation.name.ilike(keyword_filter, escape="\\"),
+                        Conversation.introduction.ilike(keyword_filter, escape="\\"),
+                        subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
                     ),
                 )
                 .group_by(Conversation.id)
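The hunks above route every user-supplied keyword through `escape_like_pattern` (imported from `libs.helper`) before it is interpolated into an `ilike` pattern. The helper's body is not part of this diff; a minimal sketch consistent with the `escape="\\"` arguments used here — an assumption, not the PR's actual implementation — would be:

```python
def escape_like_pattern(value: str) -> str:
    """Escape LIKE/ILIKE metacharacters so user input matches literally.

    The backslash is escaped first, since it is the escape character
    declared via ilike(..., escape="\\").
    """
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


assert escape_like_pattern("50%_off") == "50\\%\\_off"
```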
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 12ada8b798..66f4524156 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
         "status": fields.String,
         "error": fields.String,
         "parent_message_id": fields.String,
+        "generation_detail": fields.Raw,
     },
 )
 
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 772d98822e..4a52bf8abe 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import flask_login
 from flask import make_response, request
 from flask_restx import Resource
@@ -96,14 +98,13 @@ class LoginApi(Resource):
         if is_login_error_rate_limit:
             raise EmailPasswordLoginLimitError()
 
-        # TODO: why invitation is re-assigned with different type?
-        invitation = args.invite_token  # type: ignore
-        if invitation:
-            invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
+        invitation_data: dict[str, Any] | None = None
+        if args.invite_token:
+            invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
 
         try:
-            if invitation:
-                data = invitation.get("data", {})  # type: ignore
+            if invitation_data:
+                data = invitation_data.get("data", {})
                 invitee_email = data.get("email") if data else None
                 if invitee_email != args.email:
                     raise InvalidEmailError()
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 5a536af6d2..16fecb41c6 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
+from libs.helper import escape_like_pattern
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import ChildChunk, DocumentSegment
 from models.model import UploadFile
@@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.where(DocumentSegment.hit_count >= hit_count_gte)
 
         if keyword:
+            # Escape LIKE wildcards (% and _) in the keyword so user input is matched literally
+            escaped_keyword = escape_like_pattern(keyword)
             # Search in both content and keywords fields
             # Use database-specific methods for JSON array search
             if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
@@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource):
                         .scalar_subquery()
                     ),
                     ",",
-                ).ilike(f"%{keyword}%")
+                ).ilike(f"%{escaped_keyword}%", escape="\\")
             else:
                 # MySQL: Cast JSON to string for pattern matching
                 # MySQL stores Chinese text directly in JSON without Unicode escaping
-                keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
+                keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
 
             query = query.where(
                 or_(
-                    DocumentSegment.content.ilike(f"%{keyword}%"),
+                    DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
                     keywords_condition,
                 )
             )
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index db7c50f422..db1a874437 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Any
 
-from flask_restx import marshal, reqparse
+from flask_restx import marshal
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
@@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
         HitTestingService.hit_testing_args_check(args)
 
     @staticmethod
-    def parse_args():
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("query", type=str, required=False, location="json")
-            .add_argument("attachment_ids", type=list, required=False, location="json")
-            .add_argument("retrieval_model", type=dict, required=False, location="json")
-            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
-        )
-        return parser.parse_args()
+    def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
+        """Validate and return hit-testing arguments from an
incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 46d67f0581..02efc54eea 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource): pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE, streaming=streaming, ) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 29417dc896..109a3cd0d3 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden import services @@ -15,18 +15,21 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, setup_required, ) from extensions.ext_database import db -from fields.file_fields import file_fields, upload_config_fields +from fields.file_fields import FileResponse, UploadConfig from libs.login import current_account_with_tenant, login_required from services.file_service import FileService from . 
import console_ns +register_schema_models(console_ns, UploadConfig, FileResponse) + PREVIEW_WORDS_LIMIT = 3000 @@ -35,26 +38,27 @@ class FileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(upload_config_fields) + @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__]) def get(self): - return { - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, - "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT, - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, - "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, - "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, - "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, - }, 200 + config = UploadConfig( + file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT, + batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT, + file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT, + image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT, + single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, + ) + return config.model_dump(mode="json"), 200 @setup_required @login_required @account_initialization_required - @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") + @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() source_str = request.form.get("source") @@ -90,7 +94,8 @@ class FileApi(Resource): except services.errors.file.BlockedFileExtensionError as blocked_extension_error: raise BlockedFileExtensionError(blocked_extension_error.description) - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 @console_ns.route("/files//preview") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 47eef7eb7e..70c7b80ffa 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -11,19 +11,22 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant from services.file_service import FileService from . 
import console_ns +register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl) + @console_ns.route("/remote-files/") class RemoteFileInfoApi(Resource): - @marshal_with(remote_file_info_fields) + @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__]) def get(self, url): decoded_url = urllib.parse.unquote(url) resp = ssrf_proxy.head(decoded_url) @@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", 0)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ) + return info.model_dump(mode="json") class RemoteFileUploadPayload(BaseModel): @@ -50,7 +54,7 @@ console_ns.schema_model( @console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) - @marshal_with(file_fields_with_signed_url) + @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) def post(self): args = RemoteFileUploadPayload.model_validate(console_ns.payload) url = args.url @@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload.model_dump(mode="json"), 201 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 55eaa2f09f..03ad0f423b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Literal @@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel): repeat_new_password: str @model_validator(mode="after") - def check_passwords_match(self) -> "AccountPasswordPayload": + def check_passwords_match(self) -> AccountPasswordPayload: if self.new_password != self.repeat_new_password: raise RepeatPasswordNotMatchError() return self diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 6096a87c56..28ec4b3935 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -4,18 +4,18 @@ from flask import request from flask_restx import Resource from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import build_file_model +from 
fields.file_fields import FileResponse from ..common.errors import ( FileTooLargeError, UnsupportedFileTypeError, ) +from ..common.schema import register_schema_models from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user @@ -35,6 +35,8 @@ files_ns.schema_model( PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) +register_schema_models(files_ns, FileResponse) + @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource): 415: "Unsupported file type", } ) - @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) + @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__]) def post(self): """Upload a file for plugin usage. @@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource): """ args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage | None = request.files.get("file") + file = request.files.get("file") if file is None: raise Forbidden("File is required.") @@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource): user_id = args.user_id user = get_user(tenant_id, user_id) - filename: str | None = file.filename - mimetype: str | None = file.mimetype + filename = file.filename + mimetype = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource): preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) # Create a dictionary with all the necessary attributes - result = { - "id": tool_file.id, - "user_id": tool_file.user_id, - "tenant_id": tool_file.tenant_id, - "conversation_id": tool_file.conversation_id, - "file_key": tool_file.file_key, - "mimetype": tool_file.mimetype, - "original_url": tool_file.original_url, - "name": tool_file.name, - "size": tool_file.size, - "mime_type": mimetype, - "extension": extension, - "preview_url": preview_url, - } + result = FileResponse( + id=tool_file.id, + name=tool_file.name, + size=tool_file.size, + extension=extension, + mime_type=mimetype, + preview_url=preview_url, + source_url=tool_file.original_url, + original_url=tool_file.original_url, + user_id=tool_file.user_id, + tenant_id=tool_file.tenant_id, + conversation_id=tool_file.conversation_id, + file_key=tool_file.file_key, + ) - return result, 201 + return result.model_dump(mode="json"), 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index ffe4e0b492..6f6dadf768 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -10,13 +10,16 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from models import App, EndUser from services.file_service import FileService +register_schema_models(service_api_ns, FileResponse) + 
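`register_schema_models` (from `controllers.common.schema`) is invoked in nearly every touched controller, but its definition sits outside these hunks. Judging by the inline `files_ns.schema_model(PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))` call visible in the `files/upload.py` hunk, it plausibly reduces to a loop over that pattern — a sketch under that assumption, with the ref-template constant likewise assumed:

```python
from flask_restx import Namespace
from pydantic import BaseModel

# Assumed value; the constant's name appears verbatim in the files/upload.py hunk.
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    """Register each Pydantic model's JSON schema on the namespace under its
    class name, so decorators can reference it as ns.models[Model.__name__]."""
    for model in models:
        ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
```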
@service_api_ns.route("/files/upload") class FileApi(Resource): @@ -31,8 +34,8 @@ class FileApi(Resource): 415: "Unsupported file type", } ) - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore + @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__]) def post(self, app_model: App, end_user: EndUser): """Upload a file for use in conversations. @@ -64,4 +67,5 @@ class FileApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..8dbb690901 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 0a2017e2bd..70b5030237 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource): pipeline=pipeline, user=current_user, args=payload.model_dump(), - invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER, streaming=payload.response_mode == "streaming", ) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 527eef6094..e76649495a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,9 +1,11 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from typing import Literal + +from flask import request +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource @@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class 
ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict() + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: @@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) adapter = TypeAdapter(SimpleConversation) conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] @@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: conversation = ConversationService.rename( - app_model, conversation_id, end_user, args["name"], args["auto_generate"] + app_model, conversation_id, end_user, payload.name, payload.auto_generate ) return ( TypeAdapter(SimpleConversation) diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 80ad61e549..0036c90800 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,4 @@ from flask import request -from flask_restx import marshal_with import services from controllers.common.errors import ( @@ -9,12 +8,15 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from services.file_service import FileService +register_schema_models(web_ns, FileResponse) + @web_ns.route("/files/upload") class FileApi(WebApiResource): @@ -28,7 +30,7 @@ class FileApi(WebApiResource): 415: "Unsupported file type", } ) - @marshal_with(build_file_model(web_ns)) + @web_ns.response(201, 
"File uploaded successfully", web_ns.models[FileResponse.__name__]) def post(self, app_model, end_user): """Upload a file for use in web applications. @@ -81,4 +83,5 @@ class FileApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index c1f976829f..b08b3fe858 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import marshal_with from pydantic import BaseModel, Field, HttpUrl import services @@ -14,7 +13,7 @@ from controllers.common.errors import ( from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService from ..common.schema import register_schema_models @@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel): url: HttpUrl = Field(description="Remote file URL") -register_schema_models(web_ns, RemoteFileUploadPayload) +register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl) @web_ns.route("/remote-files/") @@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_remote_file_info_model(web_ns)) + @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) def get(self, app_model, end_user, url): """Get information about a remote file. @@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", -1)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", -1)), + ) + return info.model_dump(mode="json") @web_ns.route("/remote-files/upload") @@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_file_with_signed_url_model(web_ns)) + @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) def post(self, app_model, end_user): """Upload a file from a remote URL. 
@@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload1 = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload1.model_dump(mode="json"), 201 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 4e20690e9e..29993100f6 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,18 +1,30 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from flask import request +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.helper import uuid_value +from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): @web_ns.doc("Get Saved Messages") @@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = SavedMessageListQuery.model_validate(raw_args) - pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) adapter = TypeAdapter(SavedMessageItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] return SavedMessageInfiniteScrollPagination( @@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + 
SavedMessageService.save(app_model, end_user, payload.message_id)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py
new file mode 100644
index 0000000000..2ee0a23aab
--- /dev/null
+++ b/api/core/agent/agent_app_runner.py
@@ -0,0 +1,380 @@
+import logging
+from collections.abc import Generator
+from copy import deepcopy
+from typing import Any
+
+from core.agent.base_agent_runner import BaseAgentRunner
+from core.agent.entities import AgentEntity, AgentLog, AgentResult
+from core.agent.patterns.strategy_factory import StrategyFactory
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
+from core.file import file_manager
+from core.model_runtime.entities import (
+    AssistantPromptMessage,
+    LLMResult,
+    LLMResultChunk,
+    LLMUsage,
+    PromptMessage,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.__base.tool import Tool
+from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool_engine import ToolEngine
+from models.model import Message
+
+logger = logging.getLogger(__name__)
+
+
+class AgentAppRunner(BaseAgentRunner):
+    def _create_tool_invoke_hook(self, message: Message):
+        """
+        Create a tool invoke hook that uses ToolEngine.agent_invoke.
+        This hook handles file creation and returns proper meta information.
+        """
+        # Get trace manager from app generate entity
+        trace_manager = self.application_generate_entity.trace_manager
+
+        def tool_invoke_hook(
+            tool: Tool, tool_args: dict[str, Any], tool_name: str
+        ) -> tuple[str, list[str], ToolInvokeMeta]:
+            """Hook that uses agent_invoke for proper file and meta handling."""
+            tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
+                tool=tool,
+                tool_parameters=tool_args,
+                user_id=self.user_id,
+                tenant_id=self.tenant_id,
+                message=message,
+                invoke_from=self.application_generate_entity.invoke_from,
+                agent_tool_callback=self.agent_callback,
+                trace_manager=trace_manager,
+                app_id=self.application_generate_entity.app_config.app_id,
+                message_id=message.id,
+                conversation_id=self.conversation.id,
+            )
+
+            # Publish files and track IDs
+            for message_file_id in message_files:
+                self.queue_manager.publish(
+                    QueueMessageFileEvent(message_file_id=message_file_id),
+                    PublishFrom.APPLICATION_MANAGER,
+                )
+                self._current_message_file_ids.append(message_file_id)
+
+            return tool_invoke_response, message_files, tool_invoke_meta
+
+        return tool_invoke_hook
+
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
+        """
+        Run Agent application
+        """
+        self.query = query
+        app_generate_entity = self.application_generate_entity
+
+        app_config = self.app_config
+        assert app_config is not None, "app_config is required"
+        assert app_config.agent is not None, "app_config.agent is required"
+
+        # convert tools into ModelRuntime Tool format
+        tool_instances, _ = self._init_prompt_tools()
+
+        assert app_config.agent
+
+        # Create tool invoke hook for agent_invoke
+        tool_invoke_hook = self._create_tool_invoke_hook(message)
+
+        # Get instruction for ReAct strategy
+        instruction = self.app_config.prompt_template.simple_prompt_template or ""
+
+        # Use factory to create appropriate strategy
+        strategy = StrategyFactory.create_strategy(
+            model_features=self.model_features,
+            model_instance=self.model_instance,
+            tools=list(tool_instances.values()),
+            files=list(self.files),
+            max_iterations=app_config.agent.max_iteration,
+            context=self.build_execution_context(),
+            agent_strategy=self.config.strategy,
+            tool_invoke_hook=tool_invoke_hook,
+            instruction=instruction,
+        )
+
+        # Initialize state variables
+        current_agent_thought_id = None
+        has_published_thought = False
+        current_tool_name: str | None = None
+        self._current_message_file_ids: list[str] = []
+
+        # organize prompt messages
+        prompt_messages = self._organize_prompt_messages()
+
+        # Run strategy
+        generator = strategy.run(
+            prompt_messages=prompt_messages,
+            model_parameters=app_generate_entity.model_conf.parameters,
+            stop=app_generate_entity.model_conf.stop,
+            stream=True,
+        )
+
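The consume loop that follows retrieves the strategy's final `AgentResult` from `StopIteration.value` — the return value of the generator — while streaming its yielded chunks. A minimal standalone illustration of that pattern (toy generator, not PR code):

```python
from collections.abc import Generator


def toy_strategy() -> Generator[str, None, int]:
    yield "chunk-1"
    yield "chunk-2"
    return 42  # surfaces as StopIteration.value for the caller


gen = toy_strategy()
chunks: list[str] = []
while True:
    try:
        chunks.append(next(gen))  # stream outputs as they arrive
    except StopIteration as stop:
        result = stop.value  # the generator's return value
        break
assert chunks == ["chunk-1", "chunk-2"] and result == 42
```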
+        # Consume generator and collect result
+        result: AgentResult | None = None
+        try:
+            while True:
+                try:
+                    output = next(generator)
+                except StopIteration as e:
+                    # Generator finished, get the return value
+                    result = e.value
+                    break
+
+                if isinstance(output, LLMResultChunk):
+                    # Handle LLM chunk
+                    if current_agent_thought_id and not has_published_thought:
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
+                            PublishFrom.APPLICATION_MANAGER,
+                        )
+                        has_published_thought = True
+
+                    yield output
+
+                elif isinstance(output, AgentLog):
+                    # Handle Agent Log using log_type for type-safe dispatch
+                    if output.status == AgentLog.LogStatus.START:
+                        if output.log_type == AgentLog.LogType.ROUND:
+                            # Start of a new round
+                            message_file_ids: list[str] = []
+                            current_agent_thought_id = self.create_agent_thought(
+                                message_id=message.id,
+                                message="",
+                                tool_name="",
+                                tool_input="",
+                                messages_ids=message_file_ids,
+                            )
+                            has_published_thought = False
+
+                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
+                            if current_agent_thought_id is None:
+                                continue
+
+                            # Tool call start - extract data from structured fields
+                            current_tool_name = output.data.get("tool_name", "")
+                            tool_input = output.data.get("tool_args", {})
+
+                            self.save_agent_thought(
+                                agent_thought_id=current_agent_thought_id,
+                                tool_name=current_tool_name,
+                                tool_input=tool_input,
+                                thought=None,
+                                observation=None,
+                                tool_invoke_meta=None,
+                                answer=None,
+                                messages_ids=[],
+                            )
+                            self.queue_manager.publish(
+                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
+                                PublishFrom.APPLICATION_MANAGER,
+                            )
+
+                    elif output.status == AgentLog.LogStatus.SUCCESS:
+                        if output.log_type == AgentLog.LogType.THOUGHT:
+                            if current_agent_thought_id is None:
+                                continue
+
+                            thought_text = output.data.get("thought")
+                            self.save_agent_thought(
+                                agent_thought_id=current_agent_thought_id,
+                                tool_name=None,
+                                tool_input=None,
+                                thought=thought_text,
+                                observation=None,
+                                tool_invoke_meta=None,
+                                answer=None,
+                                messages_ids=[],
+                            )
+                            self.queue_manager.publish(
+                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
+                                PublishFrom.APPLICATION_MANAGER,
+                            )
+
+                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
+                            if current_agent_thought_id is None:
+                                continue
+
+                            # Tool call finished
+                            tool_output = output.data.get("output")
+                            # Get meta from strategy output (now properly populated)
+                            tool_meta = output.data.get("meta")
+
+                            # Wrap tool_meta with tool_name as key (required by agent_service)
+                            if tool_meta and current_tool_name:
+                                tool_meta = 
{current_tool_name: tool_meta} + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=None, + observation=tool_output, + tool_invoke_meta=tool_meta, + answer=None, + messages_ids=self._current_message_file_ids, + ) + # Clear message file ids after saving + self._current_message_file_ids = [] + current_tool_name = None + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.ROUND: + if current_agent_thought_id is None: + continue + + # Round finished - save LLM usage and answer + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) + llm_result = output.data.get("llm_result") + final_answer = output.data.get("final_answer") + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=llm_result, + observation=None, + tool_invoke_meta=None, + answer=final_answer, + messages_ids=[], + llm_usage=llm_usage, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + except Exception: + # Re-raise any other exceptions + raise + + # Process final result + if isinstance(result, AgentResult): + final_answer = result.text + usage = result.usage or LLMUsage.empty_usage() + + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=self.model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) + + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Initialize system message + """ + if not prompt_template: + return prompt_messages or [] + + prompt_messages = prompt_messages or [] + + if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage(content=prompt_template) + return prompt_messages + + if not prompt_messages: + return [SystemPromptMessage(content=prompt_template)] + + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + return prompt_messages + + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + As for now, gpt supports both fc and vision at the first iteration. 
+ We need to remove the image messages from the prompt messages at the first iteration. + """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) + + return prompt_messages + + def _organize_prompt_messages(self): + # For ReAct strategy, use the agent prompt template + if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt: + prompt_template = self.config.prompt.first_prompt + else: + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" + + self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) + query_prompt_messages = self._organize_user_query(self.query or "", []) + + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=[*query_prompt_messages, *self._current_thoughts], + history_messages=self.history_prompt_messages, + memory=self.memory, + ).get_prompt() + + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] + if len(self._current_thoughts) != 0: + # clear messages after the first iteration + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + return prompt_messages diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..b59a9a3859 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,7 +5,7 @@ from typing import Union, cast from sqlalchemy import select -from core.agent.entities import AgentEntity, AgentToolEntity +from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] + self.model_features = features self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] + def build_execution_context(self) -> ExecutionContext: + """Build execution context.""" + return ExecutionContext( + user_id=self.user_id, + app_id=self.app_config.app_id, + conversation_id=self.conversation.id, + message_id=self.message.id, + tenant_id=self.tenant_id, + ) + def _repack_app_generate_entity( self, app_generate_entity: AgentChatAppGenerateEntity ) -> AgentChatAppGenerateEntity: diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py deleted file mode 100644 index a55f2d0f5f..0000000000 --- a/api/core/agent/cot_agent_runner.py +++ /dev/null @@ -1,437 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping, Sequence -from typing import Any - -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentScratchpadUnit -from 
core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from core.ops.ops_trace_manager import TraceQueueManager -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from core.workflow.nodes.agent.exc import AgentMaxIterationError -from models.model import Message - -logger = logging.getLogger(__name__) - - -class CotAgentRunner(BaseAgentRunner, ABC): - _is_first_iteration = True - _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] - _agent_scratchpad: list[AgentScratchpadUnit] - _instruction: str - _query: str - _prompt_messages_tools: Sequence[PromptMessageTool] - - def run( - self, - message: Message, - query: str, - inputs: Mapping[str, str], - ) -> Generator: - """ - Run Cot agent application - """ - - app_generate_entity = self.application_generate_entity - self._repack_app_generate_entity(app_generate_entity) - self._init_react_state(query) - - trace_manager = app_generate_entity.trace_manager - - # check model mode - if "Observation" not in app_generate_entity.model_conf.stop: - if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append("Observation") - - app_config = self.app_config - assert app_config.agent - - # init instruction - inputs = inputs or {} - instruction = app_config.prompt_template.simple_prompt_template or "" - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - self._prompt_messages_tools = prompt_messages_tools - - function_call_state = True - llm_usage: dict[str, LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - agent_thought_id = "" # Initialize agent_thought_id - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - # continue to run until there is not any tool call - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - self._prompt_messages_tools = [] - - message_file_ids: list[str] = [] - - agent_thought_id = self.create_agent_thought( - message_id=message.id, 
message="", tool_name="", tool_input="", messages_ids=message_file_ids - ) - - if iteration_step > 1: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=[], - stop=app_generate_entity.model_conf.stop, - stream=True, - user=self.user_id, - callbacks=[], - ) - - usage_dict: dict[str, LLMUsage | None] = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) - scratchpad = AgentScratchpadUnit( - agent_response="", - thought="", - action_str="", - observation="", - action=None, - ) - - # publish agent thought if it's first iteration - if iteration_step == 1: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - for chunk in react_chunks: - if isinstance(chunk, AgentScratchpadUnit.Action): - action = chunk - # detect action - assert scratchpad.agent_response is not None - scratchpad.agent_response += json.dumps(chunk.model_dump()) - scratchpad.action_str = json.dumps(chunk.model_dump()) - scratchpad.action = action - else: - assert scratchpad.agent_response is not None - scratchpad.agent_response += chunk - assert scratchpad.thought is not None - scratchpad.thought += chunk - yield LLMResultChunk( - model=self.model_config.model, - prompt_messages=prompt_messages, - system_fingerprint="", - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), - ) - - assert scratchpad.thought is not None - scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) - - # Check if max iteration is reached and model still wants to call tools - if iteration_step == max_iteration_steps and scratchpad.action: - if scratchpad.action.action_name.lower() != "final answer": - raise AgentMaxIterationError(app_config.agent.max_iteration) - - # get llm usage - if "usage" in usage_dict: - if usage_dict["usage"] is not None: - increase_usage(llm_usage, usage_dict["usage"]) - else: - usage_dict["usage"] = LLMUsage.empty_usage() - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), - tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, - tool_invoke_meta={}, - thought=scratchpad.thought or "", - observation="", - answer=scratchpad.agent_response or "", - messages_ids=[], - llm_usage=usage_dict["usage"], - ) - - if not scratchpad.is_final(): - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - if not scratchpad.action: - # failed to extract action, return final answer directly - final_answer = "" - else: - if scratchpad.action.action_name.lower() == "final answer": - # action is final answer, return final answer directly - try: - if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False) - elif isinstance(scratchpad.action.action_input, str): - final_answer = scratchpad.action.action_input - else: - final_answer = 
f"{scratchpad.action.action_input}" - except TypeError: - final_answer = f"{scratchpad.action.action_input}" - else: - function_call_state = True - # action is tool call, invoke tool - tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( - action=scratchpad.action, - tool_instances=tool_instances, - message_file_ids=message_file_ids, - trace_manager=trace_manager, - ) - scratchpad.observation = tool_invoke_response - scratchpad.agent_response = tool_invoke_response - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=scratchpad.action.action_name, - tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought or "", - observation={scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, - answer=scratchpad.agent_response, - messages_ids=message_file_ids, - llm_usage=usage_dict["usage"], - ) - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool message - for prompt_tool in self._prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] - ), - system_fingerprint="", - ) - - # save agent thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name="", - tool_input={}, - tool_invoke_meta={}, - thought=final_answer, - observation={}, - answer=final_answer, - messages_ids=[], - ) - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - def _handle_invoke_action( - self, - action: AgentScratchpadUnit.Action, - tool_instances: Mapping[str, Tool], - message_file_ids: list[str], - trace_manager: TraceQueueManager | None = None, - ) -> tuple[str, ToolInvokeMeta]: - """ - handle invoke action - :param action: action - :param tool_instances: tool instances - :param message_file_ids: message file ids - :param trace_manager: trace manager - :return: observation, meta - """ - # action is tool call, invoke tool - tool_call_name = action.action_name - tool_call_args = action.action_input - tool_instance = tool_instances.get(tool_call_name) - - if not tool_instance: - answer = f"there is not a tool named {tool_call_name}" - return answer, ToolInvokeMeta.error_instance(answer) - - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass - - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - ) - - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), 
PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - return tool_invoke_response, tool_invoke_meta - - def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: - """ - convert dict to action - """ - return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: - """ - fill in inputs from external data tools - """ - for key, value in inputs.items(): - try: - instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) - except Exception: - continue - - return instruction - - def _init_react_state(self, query): - """ - init agent scratchpad - """ - self._query = query - self._agent_scratchpad = [] - self._historic_prompt_messages = self._organize_historic_prompt_messages() - - @abstractmethod - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - organize prompt messages - """ - - def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: - """ - format assistant message - """ - message = "" - for scratchpad in agent_scratchpad: - if scratchpad.is_final(): - message += f"Final Answer: {scratchpad.agent_response}" - else: - message += f"Thought: {scratchpad.thought}\n\n" - if scratchpad.action_str: - message += f"Action: {scratchpad.action_str}\n\n" - if scratchpad.observation: - message += f"Observation: {scratchpad.observation}\n\n" - - return message - - def _organize_historic_prompt_messages( - self, current_session_messages: list[PromptMessage] | None = None - ) -> list[PromptMessage]: - """ - organize historic prompt messages - """ - result: list[PromptMessage] = [] - scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit | None = None - - for message in self.history_prompt_messages: - if isinstance(message, AssistantPromptMessage): - if not current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad = AgentScratchpadUnit( - agent_response=message.content, - thought=message.content or "I am thinking about how to help you", - action_str="", - action=None, - observation=None, - ) - scratchpads.append(current_scratchpad) - if message.tool_calls: - try: - current_scratchpad.action = AgentScratchpadUnit.Action( - action_name=message.tool_calls[0].function.name, - action_input=json.loads(message.tool_calls[0].function.arguments), - ) - current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) - except Exception: - logger.exception("Failed to parse tool call from assistant message") - elif isinstance(message, ToolPromptMessage): - if current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad.observation = message.content - else: - raise NotImplementedError("expected str type") - elif isinstance(message, UserPromptMessage): - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - scratchpads = [] - current_scratchpad = None - - result.append(message) - - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - - historic_prompts = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=current_session_messages or [], - history_messages=result, - memory=self.memory, - ).get_prompt() - return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py 
b/api/core/agent/cot_chat_agent_runner.py deleted file mode 100644 index 4d1d94eadc..0000000000 --- a/api/core/agent/cot_chat_agent_runner.py +++ /dev/null @@ -1,118 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - PromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotChatAgentRunner(CotAgentRunner): - def _organize_system_prompt(self) -> SystemPromptMessage: - """ - Organize system prompt - """ - assert self.app_config.agent - assert self.app_config.agent.prompt - - prompt_entity = self.app_config.agent.prompt - if not prompt_entity: - raise ValueError("Agent prompt configuration is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return SystemPromptMessage(content=system_prompt) - - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize - """ - # organize system prompt - system_message = self._organize_system_prompt() - - # organize current assistant messages - agent_scratchpad = self._agent_scratchpad - if not agent_scratchpad: - assistant_messages = [] - else: - assistant_message = AssistantPromptMessage(content="") - assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str - for unit in agent_scratchpad: - if unit.is_final(): - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Final Answer: {unit.agent_response}" - else: - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_message.content += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_message.content += f"Observation: {unit.observation}\n\n" - - assistant_messages = [assistant_message] - - # query messages - query_messages = self._organize_user_query(self._query, []) - - if assistant_messages: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages( - [system_message, 
*query_messages, *assistant_messages, UserPromptMessage(content="continue")] - ) - messages = [ - system_message, - *historic_messages, - *query_messages, - *assistant_messages, - UserPromptMessage(content="continue"), - ] - else: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) - messages = [system_message, *historic_messages, *query_messages] - - # join all messages - return messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py deleted file mode 100644 index da9a001d84..0000000000 --- a/api/core/agent/cot_completion_agent_runner.py +++ /dev/null @@ -1,87 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotCompletionAgentRunner(CotAgentRunner): - def _organize_instruction_prompt(self) -> str: - """ - Organize instruction prompt - """ - if self.app_config.agent is None: - raise ValueError("Agent configuration is not set") - prompt_entity = self.app_config.agent.prompt - if prompt_entity is None: - raise ValueError("prompt entity is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return system_prompt - - def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: - """ - Organize historic prompt - """ - historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) - historic_prompt = "" - - for message in historic_prompt_messages: - if isinstance(message, UserPromptMessage): - historic_prompt += f"Question: {message.content}\n\n" - elif isinstance(message, AssistantPromptMessage): - if isinstance(message.content, str): - historic_prompt += message.content + "\n\n" - elif isinstance(message.content, list): - for content in message.content: - if not isinstance(content, TextPromptMessageContent): - continue - historic_prompt += content.data - - return historic_prompt - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize prompt messages - """ - # organize system prompt - system_prompt = self._organize_instruction_prompt() - - # organize historic prompt messages - historic_prompt = self._organize_historic_prompt() - - # organize current assistant messages - agent_scratchpad = self._agent_scratchpad - assistant_prompt = "" - for unit in agent_scratchpad or []: - if unit.is_final(): - assistant_prompt += f"Final Answer: {unit.agent_response}" - else: - assistant_prompt += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_prompt += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_prompt += f"Observation: {unit.observation}\n\n" - - # query messages - query_prompt = f"Question: {self._query}" - - # join all messages - prompt = ( - system_prompt.replace("{{historic_messages}}", historic_prompt) - .replace("{{agent_scratchpad}}", assistant_prompt) - .replace("{{query}}", query_prompt) - ) - - return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py 
b/api/core/agent/entities.py index 220feced1d..46af4d2d72 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,3 +1,5 @@ +import uuid +from collections.abc import Mapping from enum import StrEnum from typing import Any, Union @@ -92,3 +94,96 @@ class AgentInvokeMessage(ToolInvokeMessage): """ pass + + +class ExecutionContext(BaseModel): + """Execution context containing trace and audit information. + + This context carries all the IDs and metadata that are not part of + the core business logic but needed for tracing, auditing, and + correlation purposes. + """ + + user_id: str | None = None + app_id: str | None = None + conversation_id: str | None = None + message_id: str | None = None + tenant_id: str | None = None + + @classmethod + def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext": + """Create a minimal context with only essential fields.""" + return cls(user_id=user_id) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for passing to legacy code.""" + return { + "user_id": self.user_id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "message_id": self.message_id, + "tenant_id": self.tenant_id, + } + + def with_updates(self, **kwargs) -> "ExecutionContext": + """Create a new context with updated fields.""" + data = self.to_dict() + data.update(kwargs) + + return ExecutionContext( + user_id=data.get("user_id"), + app_id=data.get("app_id"), + conversation_id=data.get("conversation_id"), + message_id=data.get("message_id"), + tenant_id=data.get("tenant_id"), + ) + + +class AgentLog(BaseModel): + """ + Agent Log. + """ + + class LogType(StrEnum): + """Type of agent log entry.""" + + ROUND = "round" # A complete iteration round + THOUGHT = "thought" # LLM thinking/reasoning + TOOL_CALL = "tool_call" # Tool invocation + + class LogMetadata(StrEnum): + STARTED_AT = "started_at" + FINISHED_AT = "finished_at" + ELAPSED_TIME = "elapsed_time" + TOTAL_PRICE = "total_price" + TOTAL_TOKENS = "total_tokens" + PROVIDER = "provider" + CURRENCY = "currency" + LLM_USAGE = "llm_usage" + ICON = "icon" + ICON_DARK = "icon_dark" + + class LogStatus(StrEnum): + START = "start" + ERROR = "error" + SUCCESS = "success" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log") + label: str = Field(..., description="The label of the log") + log_type: LogType = Field(..., description="The type of the log") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") + status: LogStatus = Field(..., description="The status of the log") + data: Mapping[str, Any] = Field(..., description="Detailed log data") + metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log") + + +class AgentResult(BaseModel): + """ + Agent execution result. 
+ """ + + text: str = Field(default="", description="The generated text") + files: list[Any] = Field(default_factory=list, description="Files produced during execution") + usage: Any | None = Field(default=None, description="LLM usage statistics") + finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py deleted file mode 100644 index 68d14ad027..0000000000 --- a/api/core/agent/fc_agent_runner.py +++ /dev/null @@ -1,470 +0,0 @@ -import json -import logging -from collections.abc import Generator -from copy import deepcopy -from typing import Any, Union - -from core.agent.base_agent_runner import BaseAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, - PromptMessage, - PromptMessageContentType, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from core.workflow.nodes.agent.exc import AgentMaxIterationError -from models.model import Message - -logger = logging.getLogger(__name__) - - -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: - """ - Run FunctionCall agent application - """ - self.query = query - app_generate_entity = self.application_generate_entity - - app_config = self.app_config - assert app_config is not None, "app_config is required" - assert app_config.agent is not None, "app_config.agent is required" - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - - assert app_config.agent - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # continue to run until there is not any tool call - function_call_state = True - llm_usage: dict[str, LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - - # get tracing instance - trace_manager = app_generate_entity.trace_manager - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - prompt_messages_tools = [] - - message_file_ids: list[str] = [] - agent_thought_id = self.create_agent_thought( - message_id=message.id, message="", 
tool_name="", tool_input="", messages_ids=message_file_ids - ) - - # recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=prompt_messages_tools, - stop=app_generate_entity.model_conf.stop, - stream=self.stream_tool_call, - user=self.user_id, - callbacks=[], - ) - - tool_calls: list[tuple[str, str, dict[str, Any]]] = [] - - # save full response - response = "" - - # save tool call names and inputs - tool_call_names = "" - tool_call_inputs = "" - - current_llm_usage = None - - if isinstance(chunks, Generator): - is_first_chunk = True - for chunk in chunks: - if is_first_chunk: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - is_first_chunk = False - # check if there is any tool call - if self.check_tool_calls(chunk): - function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if chunk.delta.message and chunk.delta.message.content: - if isinstance(chunk.delta.message.content, list): - for content in chunk.delta.message.content: - response += content.data - else: - response += str(chunk.delta.message.content) - - if chunk.delta.usage: - increase_usage(llm_usage, chunk.delta.usage) - current_llm_usage = chunk.delta.usage - - yield chunk - else: - result = chunks - # check if there is any tool call - if self.check_blocking_tool_calls(result): - function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if result.usage: - increase_usage(llm_usage, result.usage) - current_llm_usage = result.usage - - if result.message and result.message.content: - if isinstance(result.message.content, list): - for content in result.message.content: - response += content.data - else: - response += str(result.message.content) - - if not result.message.content: - result.message.content = "" - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=result.prompt_messages, - system_fingerprint=result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=result.message, - usage=result.usage, - ), - ) - - assistant_message = AssistantPromptMessage(content="", tool_calls=[]) - if tool_calls: - assistant_message.tool_calls = [ - AssistantPromptMessage.ToolCall( - id=tool_call[0], - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], 
arguments=json.dumps(tool_call[2], ensure_ascii=False) - ), - ) - for tool_call in tool_calls - ] - else: - assistant_message.content = response - - self._current_thoughts.append(assistant_message) - - # save thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=tool_call_names, - tool_input=tool_call_inputs, - thought=response, - tool_invoke_meta=None, - observation=None, - answer=response, - messages_ids=[], - llm_usage=current_llm_usage, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - final_answer += response + "\n" - - # Check if max iteration is reached and model still wants to call tools - if iteration_step == max_iteration_steps and tool_calls: - raise AgentMaxIterationError(app_config.agent.max_iteration) - - # call tools - tool_responses = [] - for tool_call_id, tool_call_name, tool_call_args in tool_calls: - tool_instance = tool_instances.get(tool_call_name) - if not tool_instance: - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), - } - else: - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - app_id=self.application_generate_entity.app_config.app_id, - message_id=self.message.id, - conversation_id=self.conversation.id, - ) - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict(), - } - - tool_responses.append(tool_response) - if tool_response["tool_response"] is not None: - self._current_thoughts.append( - ToolPromptMessage( - content=str(tool_response["tool_response"]), - tool_call_id=tool_call_id, - name=tool_call_name, - ) - ) - - if len(tool_responses) > 0: - # save agent thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name="", - tool_input="", - thought="", - tool_invoke_meta={ - tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses - }, - observation={ - tool_response["tool_call_name"]: tool_response["tool_response"] - for tool_response in tool_responses - }, - answer="", - messages_ids=message_file_ids, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool - for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - 
), - PublishFrom.APPLICATION_MANAGER, - ) - - def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: - """ - Check if there is any tool call in llm result chunk - """ - if llm_result_chunk.delta.message.tool_calls: - return True - return False - - def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: - """ - Check if there is any blocking tool call in llm result - """ - if llm_result.message.tool_calls: - return True - return False - - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract tool calls from llm result chunk - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result_chunk.delta.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract blocking tool calls from llm result - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Initialize system message - """ - if not prompt_messages and prompt_template: - return [ - SystemPromptMessage(content=prompt_template), - ] - - if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: - prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - - return prompt_messages or [] - - def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - As for now, gpt supports both fc and vision at the first iteration. - We need to remove the image messages from the prompt messages at the first iteration. 
- """ - prompt_messages = deepcopy(prompt_messages) - - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, list): - prompt_message.content = "\n".join( - [ - content.data - if content.type == PromptMessageContentType.TEXT - else "[image]" - if content.type == PromptMessageContentType.IMAGE - else "[file]" - for content in prompt_message.content - ] - ) - - return prompt_messages - - def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or "" - self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query or "", []) - - self.history_prompt_messages = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=[*query_prompt_messages, *self._current_thoughts], - history_messages=self.history_prompt_messages, - memory=self.memory, - ).get_prompt() - - prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] - if len(self._current_thoughts) != 0: - # clear messages after the first iteration - prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) - return prompt_messages diff --git a/api/core/agent/patterns/README.md b/api/core/agent/patterns/README.md new file mode 100644 index 0000000000..95b1bf87fa --- /dev/null +++ b/api/core/agent/patterns/README.md @@ -0,0 +1,55 @@ +# Agent Patterns + +A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability. + +## Overview + +The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently. + +## Key Features + +- **Dual strategies** + - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`. + - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested. +- **Explicit or auto selection** + - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT). + - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not. +- **Unified execution contract** + - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`. + - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools. +- **Tool handling and hooks** + - Tools convert to `PromptMessageTool` objects before invocation. + - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`. + - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs. +- **File-aware arguments** + - Tool args accept `[File: ]` or `[Files: ]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely. 
+- **ReAct prompt shaping** + - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders. + - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history. +- **Observability and accounting** + - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths. + +## Architecture + +``` +agent/patterns/ +├── base.py # Shared utilities: logging, usage, tool invocation, file handling +├── function_call.py # Native function-calling loop with tool execution +├── react.py # ReAct loop with CoT parsing and scratchpad wiring +└── strategy_factory.py # Strategy selection by model features or explicit override +``` + +## Usage + +- For auto-selection: + - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params. +- For explicit behavior: + - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct. +- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`. + +## Integration Points + +- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls. +- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers. +- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments. +- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging. diff --git a/api/core/agent/patterns/__init__.py b/api/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..8a3b125533 --- /dev/null +++ b/api/core/agent/patterns/__init__.py @@ -0,0 +1,19 @@ +"""Agent patterns module. 
+ +This module provides different strategies for agent execution: +- FunctionCallStrategy: Uses native function/tool calling +- ReActStrategy: Uses ReAct (Reasoning + Acting) approach +- StrategyFactory: Factory for creating strategies based on model features +""" + +from .base import AgentPattern +from .function_call import FunctionCallStrategy +from .react import ReActStrategy +from .strategy_factory import StrategyFactory + +__all__ = [ + "AgentPattern", + "FunctionCallStrategy", + "ReActStrategy", + "StrategyFactory", +] diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py new file mode 100644 index 0000000000..d98fa005a3 --- /dev/null +++ b/api/core/agent/patterns/base.py @@ -0,0 +1,474 @@ +"""Base class for agent strategies.""" + +from __future__ import annotations + +import json +import re +import time +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING, Any + +from core.agent.entities import AgentLog, AgentResult, ExecutionContext +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + +# Type alias for tool invoke hook +# Returns: (response_content, message_file_ids, tool_invoke_meta) +ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]] + + +class AgentPattern(ABC): + """Base class for agent execution strategies.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + ): + """Initialize the agent strategy.""" + self.model_instance = model_instance + self.tools = tools + self.context = context + self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations + self.workflow_call_depth = workflow_call_depth + # Copy the list: _invoke_tool appends to self.files, and the shared + # mutable default [] must not leak state across instances. + self.files: list[File] = list(files) + self.tool_invoke_hook = tool_invoke_hook + + @abstractmethod + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the agent strategy.""" + pass + + def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None: + """Accumulate LLM usage statistics.""" + if not total_usage.get("usage"): + # Create a copy to avoid modifying the original + total_usage["usage"] = LLMUsage( + prompt_tokens=delta_usage.prompt_tokens, + prompt_unit_price=delta_usage.prompt_unit_price, + prompt_price_unit=delta_usage.prompt_price_unit, + prompt_price=delta_usage.prompt_price, + completion_tokens=delta_usage.completion_tokens, + completion_unit_price=delta_usage.completion_unit_price, + completion_price_unit=delta_usage.completion_price_unit, + completion_price=delta_usage.completion_price, + total_tokens=delta_usage.total_tokens, + total_price=delta_usage.total_price, + currency=delta_usage.currency, + latency=delta_usage.latency, + ) + else: + current: LLMUsage = total_usage["usage"] + 
current.prompt_tokens += delta_usage.prompt_tokens + current.completion_tokens += delta_usage.completion_tokens + current.total_tokens += delta_usage.total_tokens + current.prompt_price += delta_usage.prompt_price + current.completion_price += delta_usage.completion_price + current.total_price += delta_usage.total_price + + def _extract_content(self, content: Any) -> str: + """Extract text content from message content.""" + if isinstance(content, list): + # Content items are PromptMessageContentUnionTypes + text_parts = [] + for c in content: + # Check if it's a TextPromptMessageContent (which has data attribute) + if isinstance(c, TextPromptMessageContent): + text_parts.append(c.data) + return "".join(text_parts) + return str(content) + + def _has_tool_calls(self, chunk: LLMResultChunk) -> bool: + """Check if chunk contains tool calls.""" + # LLMResultChunk always has delta attribute + return bool(chunk.delta.message and chunk.delta.message.tool_calls) + + def _has_tool_calls_result(self, result: LLMResult) -> bool: + """Check if result contains tool calls (non-streaming).""" + # LLMResult always has message attribute + return bool(result.message and result.message.tool_calls) + + def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from streaming chunk.""" + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + if chunk.delta.message and chunk.delta.message.tool_calls: + for tool_call in chunk.delta.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from non-streaming result.""" + tool_calls = [] + if result.message and result.message.tool_calls: + for tool_call in result.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_text_from_message(self, message: PromptMessage) -> str: + """Extract text content from a prompt message.""" + # PromptMessage always has content attribute + content = message.content + if isinstance(content, str): + return content + elif isinstance(content, list): + # Extract text from content list + text_parts = [] + for item in content: + if isinstance(item, TextPromptMessageContent): + text_parts.append(item.data) + return " ".join(text_parts) + return "" + + def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]: + """Get metadata for a tool including provider and icon info.""" + from core.tools.tool_manager import ToolManager + + metadata: dict[AgentLog.LogMetadata, Any] = {} + if tool_instance.entity and tool_instance.entity.identity: + identity = tool_instance.entity.identity + if identity.provider: + metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider + + # Get icon using ToolManager for proper URL generation + tenant_id = self.context.tenant_id + if tenant_id and identity.provider: + try: + provider_type = tool_instance.tool_provider_type() + icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider) + if isinstance(icon, str): + 
metadata[AgentLog.LogMetadata.ICON] = icon + elif isinstance(icon, dict): + # Handle icon dict with background/content or light/dark variants + metadata[AgentLog.LogMetadata.ICON] = icon + except Exception: + # Fallback to identity.icon if ToolManager fails + if identity.icon: + metadata[AgentLog.LogMetadata.ICON] = identity.icon + elif identity.icon: + metadata[AgentLog.LogMetadata.ICON] = identity.icon + return metadata + + def _create_log( + self, + label: str, + log_type: AgentLog.LogType, + status: AgentLog.LogStatus, + data: dict[str, Any] | None = None, + parent_id: str | None = None, + extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None, + ) -> AgentLog: + """Create a new AgentLog with standard metadata.""" + metadata: dict[AgentLog.LogMetadata, Any] = { + AgentLog.LogMetadata.STARTED_AT: time.perf_counter(), + } + if extra_metadata: + metadata.update(extra_metadata) + + return AgentLog( + label=label, + log_type=log_type, + status=status, + data=data or {}, + parent_id=parent_id, + metadata=metadata, + ) + + def _finish_log( + self, + log: AgentLog, + data: dict[str, Any] | None = None, + usage: LLMUsage | None = None, + ) -> AgentLog: + """Finish an AgentLog by updating its status and metadata.""" + log.status = AgentLog.LogStatus.SUCCESS + + if data is not None: + log.data = data + + # Calculate elapsed time + started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter()) + finished_at = time.perf_counter() + + # Update metadata + log.metadata = { + **log.metadata, + AgentLog.LogMetadata.FINISHED_AT: finished_at, + # Calculate elapsed time in seconds + AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4), + } + + # Add usage information if provided + if usage: + log.metadata.update( + { + AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price, + AgentLog.LogMetadata.CURRENCY: usage.currency, + AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens, + AgentLog.LogMetadata.LLM_USAGE: usage, + } + ) + + return log + + def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]: + """ + Replace file references in tool arguments with actual File objects. + + Args: + tool_args: Dictionary of tool arguments + + Returns: + Updated tool arguments with file references replaced + """ + # Process each argument in the dictionary + processed_args: dict[str, Any] = {} + for key, value in tool_args.items(): + processed_args[key] = self._process_file_reference(value) + return processed_args + + def _process_file_reference(self, data: Any) -> Any: + """ + Recursively process data to replace file references. + Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...]. 
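+
+        Example (illustrative IDs):
+            "[File: abc]"       -> the File in self.files whose id is "abc"
+            "[Files: abc, def]" -> [File, File] for the matching ids
+            "plain text"        -> returned unchanged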
+ + Args: + data: The data to process (can be dict, list, str, or other types) + + Returns: + Processed data with file references replaced + """ + single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$") + multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$") + + if isinstance(data, dict): + # Process dictionary recursively + return {key: self._process_file_reference(value) for key, value in data.items()} + elif isinstance(data, list): + # Process list recursively + return [self._process_file_reference(item) for item in data] + elif isinstance(data, str): + # Check for single file pattern [File: file_id] + single_match = single_file_pattern.match(data.strip()) + if single_match: + file_id = single_match.group(1).strip() + # Find the file in self.files + for file in self.files: + if file.id and str(file.id) == file_id: + return file + # If file not found, return original value + return data + + # Check for multiple files pattern [Files: file_id1, file_id2, ...] + multiple_match = multiple_files_pattern.match(data.strip()) + if multiple_match: + file_ids_str = multiple_match.group(1).strip() + # Split by comma and strip whitespace + file_ids = [fid.strip() for fid in file_ids_str.split(",")] + + # Find all matching files + matched_files: list[File] = [] + for file_id in file_ids: + for file in self.files: + if file.id and str(file.id) == file_id: + matched_files.append(file) + break + + # Return list of files if any were found, otherwise return original + return matched_files or data + + return data + else: + # Return other types as-is + return data + + def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk: + """Create a text chunk for streaming.""" + return LLMResultChunk( + model=self.model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + usage=None, + ), + system_fingerprint="", + ) + + def _invoke_tool( + self, + tool_instance: Tool, + tool_args: dict[str, Any], + tool_name: str, + ) -> tuple[str, list[File], ToolInvokeMeta | None]: + """ + Invoke a tool and collect its response. 
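+
+        When a tool_invoke_hook is set (agent app runs), it performs the call;
+        otherwise ToolEngine.generic_invoke is used and its streamed messages
+        are flattened into a text response plus any produced File objects.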
+ + Args: + tool_instance: The tool instance to invoke + tool_args: Tool arguments + tool_name: Name of the tool + + Returns: + Tuple of (response_content, tool_files, tool_invoke_meta) + """ + # Process tool_args to replace file references with actual File objects + tool_args = self._replace_file_references(tool_args) + + # If a tool invoke hook is set, use it instead of generic_invoke + if self.tool_invoke_hook: + response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name) + # Note: message_file_ids are stored in DB, we don't convert them to File objects here + # The caller (AgentAppRunner) handles file publishing + return response_content, [], tool_invoke_meta + + # Default: use generic_invoke for workflow scenarios + # Import here to avoid circular import + from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine + + tool_response = ToolEngine().generic_invoke( + tool=tool_instance, + tool_parameters=tool_args, + user_id=self.context.user_id or "", + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + app_id=self.context.app_id, + conversation_id=self.context.conversation_id, + message_id=self.context.message_id, + ) + + # Collect response and files + response_content = "" + tool_files: list[File] = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + response_content += response.message.text + + elif response.type == ToolInvokeMessage.MessageType.LINK: + # Handle link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Link: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE: + # Handle image URL messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK: + # Handle image link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK: + # Handle binary file link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + filename = response.meta.get("filename", "file") if response.meta else "file" + response_content += f"[File: {filename} - {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.JSON: + # Handle JSON messages + if isinstance(response.message, ToolInvokeMessage.JsonMessage): + response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2) + + elif response.type == ToolInvokeMessage.MessageType.BLOB: + # Handle blob messages - convert to text representation + if isinstance(response.message, ToolInvokeMessage.BlobMessage): + mime_type = ( + response.meta.get("mime_type", "application/octet-stream") + if response.meta + else "application/octet-stream" + ) + size = len(response.message.blob) + response_content += f"[Binary data: {mime_type}, size: {size} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + # Handle variable messages + if isinstance(response.message, ToolInvokeMessage.VariableMessage): + var_name = response.message.variable_name + var_value = response.message.variable_value + if isinstance(var_value, str): + response_content += var_value + else: + response_content += f"[Variable 
{var_name}: {json.dumps(var_value, ensure_ascii=False)}]" + + elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: + # Handle blob chunk messages - these are parts of a larger blob + if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage): + response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + # Handle retriever resources messages + if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage): + response_content += response.message.context + + elif response.type == ToolInvokeMessage.MessageType.FILE: + # Extract file from meta + if response.meta and "file" in response.meta: + file = response.meta["file"] + if isinstance(file, File): + # Check if file is for model or tool output + if response.meta.get("target") == "self": + # File is for model - add to files for next prompt + self.files.append(file) + response_content += f"File '{file.filename}' has been loaded into your context." + else: + # File is tool output + tool_files.append(file) + + return response_content, tool_files, None + + def _find_tool_by_name(self, tool_name: str) -> Tool | None: + """Find a tool instance by its name.""" + for tool in self.tools: + if tool.entity.identity.name == tool_name: + return tool + return None + + def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]: + """Convert tools to prompt message format.""" + prompt_tools: list[PromptMessageTool] = [] + for tool in self.tools: + prompt_tools.append(tool.to_prompt_message_tool()) + return prompt_tools + + def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None: + """Initialize usage tracking with empty usage if not set.""" + if "usage" not in llm_usage or llm_usage["usage"] is None: + llm_usage["usage"] = LLMUsage.empty_usage() diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py new file mode 100644 index 0000000000..a0f0d30ab3 --- /dev/null +++ b/api/core/agent/patterns/function_call.py @@ -0,0 +1,299 @@ +"""Function Call strategy implementation.""" + +import json +from collections.abc import Generator +from typing import Any, Union + +from core.agent.entities import AgentLog, AgentResult +from core.file import File +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, +) +from core.tools.entities.tool_entities import ToolInvokeMeta + +from .base import AgentPattern + + +class FunctionCallStrategy(AgentPattern): + """Function Call strategy using model's native tool calling capability.""" + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the function call agent strategy.""" + # Convert tools to prompt format + prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + + # Initialize tracking + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + function_call_state: bool = True + total_usage: dict[str, LLMUsage | None] = {"usage": None} + messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy + final_text: str = "" + finish_reason: str | None = None + output_files: list[File] = [] # Track files produced by tools + + while function_call_state and 
iteration_step <= max_iterations: + function_call_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + # On last iteration, remove tools to force final answer + current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, LLMUsage | None] = {"usage": None} + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages, + model_parameters=model_parameters, + tools=current_tools, + stop=stop, + stream=stream, + user=self.context.user_id, + callbacks=[], + ) + + # Process response + tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log + ) + messages.append(self._create_assistant_message(response_content, tool_calls)) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update final text if no tool calls (this is likely the final answer) + if not tool_calls: + final_text = response_content + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Process tool calls + tool_outputs: dict[str, str] = {} + if tool_calls: + function_call_state = True + # Execute tools + for tool_call_id, tool_name, tool_args in tool_calls: + tool_response, tool_files, _ = yield from self._handle_tool_call( + tool_name, tool_args, tool_call_id, messages, round_log + ) + tool_outputs[tool_name] = tool_response + # Track files produced by tools + output_files.extend(tool_files) + yield self._finish_log( + round_log, + data={ + "llm_result": response_content, + "tool_calls": [ + {"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls + ] + if tool_calls + else [], + "final_answer": final_text if not function_call_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, + files=output_files, + usage=total_usage.get("usage") or LLMUsage.empty_usage(), + finish_reason=finish_reason, + ) + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, LLMUsage | None], + start_log: AgentLog, + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[list[tuple[str, str, dict[str, Any]]], str, str | None], + ]: + """Handle LLM response chunks and extract tool calls and content. + + Returns a tuple of (tool_calls, response_content, finish_reason). 
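+
+        Handles both invoke_llm shapes: a streaming Generator of LLMResultChunk
+        (chunks are re-yielded as received) and a blocking LLMResult, which is
+        converted to a single equivalent chunk before being yielded.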
+ """ + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + response_content: str = "" + finish_reason: str | None = None + if isinstance(chunks, Generator): + # Streaming response + for chunk in chunks: + # Extract tool calls + if self._has_tool_calls(chunk): + tool_calls.extend(self._extract_tool_calls(chunk)) + + # Extract content + if chunk.delta.message and chunk.delta.message.content: + response_content += self._extract_content(chunk.delta.message.content) + + # Track usage + if chunk.delta.usage: + self._accumulate_usage(llm_usage, chunk.delta.usage) + + # Capture finish reason + if chunk.delta.finish_reason: + finish_reason = chunk.delta.finish_reason + + yield chunk + else: + # Non-streaming response + result: LLMResult = chunks + + if self._has_tool_calls_result(result): + tool_calls.extend(self._extract_tool_calls_result(result)) + + if result.message and result.message.content: + response_content += self._extract_content(result.message.content) + + if result.usage: + self._accumulate_usage(llm_usage, result.usage) + + # Convert to streaming format + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), + ) + yield self._finish_log( + start_log, + data={ + "result": response_content, + }, + usage=llm_usage.get("usage"), + ) + return tool_calls, response_content, finish_reason + + def _create_assistant_message( + self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None + ) -> AssistantPromptMessage: + """Create assistant message with tool calls.""" + if tool_calls is None: + return AssistantPromptMessage(content=content) + return AssistantPromptMessage( + content=content or "", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=tc[0], + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])), + ) + for tc in tool_calls + ], + ) + + def _handle_tool_call( + self, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str, + messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]: + """Handle a single tool call and return response with files and meta.""" + # Find tool + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + raise ValueError(f"Tool {tool_name} not found") + + # Get tool metadata (provider, icon, etc.) 
+ tool_metadata = self._get_tool_metadata(tool_instance) + + # Create tool call log + tool_call_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + extra_metadata=tool_metadata, + ) + yield tool_call_log + + # Invoke tool using base class method with error handling + try: + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name) + + yield self._finish_log( + tool_call_log, + data={ + **tool_call_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + final_content = response_content or "Tool executed successfully" + # Add tool response to messages + messages.append( + ToolPromptMessage( + content=final_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return response_content, tool_files, tool_invoke_meta + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_call_log.status = AgentLog.LogStatus.ERROR + tool_call_log.error = error_message + tool_call_log.data = { + **tool_call_log.data, + "error": error_message, + } + yield tool_call_log + + # Add error message to conversation + error_content = f"Tool execution failed: {error_message}" + messages.append( + ToolPromptMessage( + content=error_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return error_content, [], None diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py new file mode 100644 index 0000000000..87a9fa9b65 --- /dev/null +++ b/api/core/agent/patterns/react.py @@ -0,0 +1,418 @@ +"""ReAct strategy implementation.""" + +from __future__ import annotations + +import json +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Union + +from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + SystemPromptMessage, +) + +from .base import AgentPattern, ToolInvokeHook + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class ReActStrategy(AgentPattern): + """ReAct strategy using reasoning and acting approach.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ): + """Initialize the ReAct strategy with instruction support.""" + super().__init__( + model_instance=model_instance, + tools=tools, + context=context, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + files=files, + tool_invoke_hook=tool_invoke_hook, + ) + self.instruction = instruction + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the ReAct agent strategy.""" + # Initialize tracking + agent_scratchpad: list[AgentScratchpadUnit] = [] + iteration_step: int = 1 + 
max_iterations: int = self.max_iterations + 1 + react_state: bool = True + total_usage: dict[str, Any] = {"usage": None} + output_files: list[File] = [] # Track files produced by tools + final_text: str = "" + finish_reason: str | None = None + + # Add "Observation" to stop sequences + if "Observation" not in stop: + stop = stop.copy() + stop.append("Observation") + + while react_state and iteration_step <= max_iterations: + react_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + + # Build prompt with/without tools based on iteration + include_tools = iteration_step < max_iterations + current_messages = self._build_prompt_with_react_format( + prompt_messages, agent_scratchpad, include_tools, self.instruction + ) + + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, Any] = {"usage": None} + + # Use current messages directly (files are handled by base class if needed) + messages_to_use = current_messages + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=self.context.user_id or "", + callbacks=[], + ) + + # Process response + scratchpad, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log, current_messages + ) + agent_scratchpad.append(scratchpad) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Check if we have an action to execute + if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + react_state = True + # Execute tool + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + # Track files produced by tools + output_files.extend(tool_files) + + # Add observation to scratchpad for display + yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages) + else: + # Extract final answer + if scratchpad.action and scratchpad.action.action_input: + final_answer = scratchpad.action.action_input + if isinstance(final_answer, dict): + final_answer = json.dumps(final_answer, ensure_ascii=False) + final_text = str(final_answer) + elif scratchpad.thought: + # If no action but we have thought, use thought as final answer + final_text = scratchpad.thought + + yield self._finish_log( + round_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + "observation": scratchpad.observation or None, + "final_answer": final_text if not react_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason + ) + + def 
_build_prompt_with_react_format( + self, + original_messages: list[PromptMessage], + agent_scratchpad: list[AgentScratchpadUnit], + include_tools: bool = True, + instruction: str = "", + ) -> list[PromptMessage]: + """Build prompt messages with ReAct format.""" + # Copy messages to avoid modifying original + messages = list(original_messages) + + # Find and update the system prompt that should already exist + system_prompt_found = False + for i, msg in enumerate(messages): + if isinstance(msg, SystemPromptMessage): + system_prompt_found = True + # The system prompt from frontend already has the template, just replace placeholders + + # Format tools + tools_str = "" + tool_names = [] + if include_tools and self.tools: + # Convert tools to prompt message tools format + prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools] + tool_names = [tool.name for tool in prompt_tools] + + # Format tools as JSON for comprehensive information + from core.model_runtime.utils.encoders import jsonable_encoder + + tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2) + tool_names_str = ", ".join(f'"{name}"' for name in tool_names) + else: + tools_str = "No tools available" + tool_names_str = "" + + # Replace placeholders in the existing system prompt + updated_content = msg.content + assert isinstance(updated_content, str) + updated_content = updated_content.replace("{{instruction}}", instruction) + updated_content = updated_content.replace("{{tools}}", tools_str) + updated_content = updated_content.replace("{{tool_names}}", tool_names_str) + + # Create new SystemPromptMessage with updated content + messages[i] = SystemPromptMessage(content=updated_content) + break + + # If no system prompt found, that's unexpected but add scratchpad anyway + if not system_prompt_found: + # This shouldn't happen if frontend is working correctly + pass + + # Format agent scratchpad + scratchpad_str = "" + if agent_scratchpad: + scratchpad_parts: list[str] = [] + for unit in agent_scratchpad: + if unit.thought: + scratchpad_parts.append(f"Thought: {unit.thought}") + if unit.action_str: + scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```") + if unit.observation: + scratchpad_parts.append(f"Observation: {unit.observation}") + scratchpad_str = "\n".join(scratchpad_parts) + + # If there's a scratchpad, append it to the last message + if scratchpad_str: + messages.append(AssistantPromptMessage(content=scratchpad_str)) + + return messages + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, Any], + model_log: AgentLog, + current_messages: list[PromptMessage], + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[AgentScratchpadUnit, str | None], + ]: + """Handle LLM response chunks and extract action/thought. + + Returns a tuple of (scratchpad_unit, finish_reason). 
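+ + Example (illustrative): for a stream like "Thought: I need to search" + followed by an Action block, the parser yields the thought text as plain + chunks first, then a parsed AgentScratchpadUnit.Action object.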
+ """ + usage_dict: dict[str, Any] = {} + + # Convert non-streaming to streaming format if needed + if isinstance(chunks, LLMResult): + # Create a generator from the LLMResult + def result_to_chunks() -> Generator[LLMResultChunk, None, None]: + yield LLMResultChunk( + model=chunks.model, + prompt_messages=chunks.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=chunks.message, + usage=chunks.usage, + finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do + ), + system_fingerprint=chunks.system_fingerprint or "", + ) + + streaming_chunks = result_to_chunks() + else: + streaming_chunks = chunks + + react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict) + + # Initialize scratchpad unit + scratchpad = AgentScratchpadUnit( + agent_response="", + thought="", + action_str="", + observation="", + action=None, + ) + + finish_reason: str | None = None + + # Process chunks + for chunk in react_chunks: + if isinstance(chunk, AgentScratchpadUnit.Action): + # Action detected + action_str = json.dumps(chunk.model_dump()) + scratchpad.agent_response = (scratchpad.agent_response or "") + action_str + scratchpad.action_str = action_str + scratchpad.action = chunk + + yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) + else: + # Text chunk + chunk_text = str(chunk) + scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text + scratchpad.thought = (scratchpad.thought or "") + chunk_text + + yield self._create_text_chunk(chunk_text, current_messages) + + # Update usage + if usage_dict.get("usage"): + if llm_usage.get("usage"): + self._accumulate_usage(llm_usage, usage_dict["usage"]) + else: + llm_usage["usage"] = usage_dict["usage"] + + # Clean up thought + scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you" + + # Finish model log + yield self._finish_log( + model_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + }, + usage=llm_usage.get("usage"), + ) + + return scratchpad, finish_reason + + def _handle_tool_call( + self, + action: AgentScratchpadUnit.Action, + prompt_messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File]]]: + """Handle tool call and return observation with files.""" + tool_name = action.action_name + tool_args: dict[str, Any] | str = action.action_input + + # Find tool instance first to get metadata + tool_instance = self._find_tool_by_name(tool_name) + tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {} + + # Start tool log with tool metadata + tool_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + extra_metadata=tool_metadata, + ) + yield tool_log + + if not tool_instance: + # Finish tool log with error + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "error": f"Tool {tool_name} not found", + }, + ) + return f"Tool {tool_name} not found", [] + + # Ensure tool_args is a dict + tool_args_dict: dict[str, Any] + if isinstance(tool_args, str): + try: + tool_args_dict = json.loads(tool_args) + except json.JSONDecodeError: + tool_args_dict = {"input": tool_args} + elif not isinstance(tool_args, dict): + tool_args_dict = {"input": str(tool_args)} + else: + tool_args_dict = tool_args + + # 
Invoke tool using base class method with error handling + try: + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name) + + # Finish tool log + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + + return response_content or "Tool executed successfully", tool_files + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_log.status = AgentLog.LogStatus.ERROR + tool_log.error = error_message + tool_log.data = { + **tool_log.data, + "error": error_message, + } + yield tool_log + + return f"Tool execution failed: {error_message}", [] diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py new file mode 100644 index 0000000000..ad26075291 --- /dev/null +++ b/api/core/agent/patterns/strategy_factory.py @@ -0,0 +1,107 @@ +"""Strategy factory for creating agent strategies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.agent.entities import AgentEntity, ExecutionContext +from core.file.models import File +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelFeature + +from .base import AgentPattern, ToolInvokeHook +from .function_call import FunctionCallStrategy +from .react import ReActStrategy + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class StrategyFactory: + """Factory for creating agent strategies based on model features.""" + + # Tool calling related features + TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL} + + @staticmethod + def create_strategy( + model_features: list[ModelFeature], + model_instance: ModelInstance, + context: ExecutionContext, + tools: list[Tool], + files: list[File], + max_iterations: int = 10, + workflow_call_depth: int = 0, + agent_strategy: AgentEntity.Strategy | None = None, + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ) -> AgentPattern: + """ + Create an appropriate strategy based on model features. 
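+ + An explicit FUNCTION_CALLING request is honored only when the model + reports a tool-call feature; otherwise selection falls back to ReAct, + mirroring the auto-selection branch below.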
+ + Args: + model_features: List of model features/capabilities + model_instance: Model instance to use + context: Execution context containing trace/audit information + tools: Available tools + files: Available files + max_iterations: Maximum iterations for the strategy + workflow_call_depth: Depth of workflow calls + agent_strategy: Optional explicit strategy override + tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke) + instruction: Optional instruction for ReAct strategy + + Returns: + AgentPattern instance + """ + # If explicit strategy is provided and it's Function Calling, try to use it if supported + if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING: + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + # Fallback to ReAct if FC is requested but not supported + + # If explicit strategy is Chain of Thought (ReAct) + if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Default auto-selection logic + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + # Model supports native function calling + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + else: + # Use ReAct strategy for models without function calling + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index ee092e55c5..a2ae8dec5b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -20,6 +20,8 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion @@ -40,6 +42,7 @@ from models import Workflow from models.enums import UserFrom from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) + conversation_variable_layer = ConversationVariablePersistenceLayer( + ConversationVariableUpdater(session_factory.get_session_maker()) + ) + workflow_entry.graph_engine.layer(conversation_variable_layer) for layer in self._graph_engine_layers: 
workflow_entry.graph_engine.layer(layer) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..6c4f96c2e4 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass, field from threading import Thread from typing import Any, Union @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + ChunkType, MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueAgentLogEvent, @@ -70,13 +72,122 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models import Account, Conversation, EndUser, Message, MessageFile +from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile from models.enums import CreatorUserRole from models.workflow import Workflow logger = logging.getLogger(__name__) +@dataclass +class StreamEventBuffer: + """ + Buffer for recording stream events in order to reconstruct the generation sequence. + Records the exact order of text chunks, thoughts, and tool calls as they stream. + """ + + # Accumulated reasoning content (each thought block is a separate element) + reasoning_content: list[str] = field(default_factory=list) + # Current reasoning buffer (accumulates until we see a different event type) + _current_reasoning: str = "" + # Tool calls with their details + tool_calls: list[dict] = field(default_factory=list) + # Tool call ID to index mapping for updating results + _tool_call_id_map: dict[str, int] = field(default_factory=dict) + # Sequence of events in stream order + sequence: list[dict] = field(default_factory=list) + # Current position in answer text + _content_position: int = 0 + # Track last event type to detect transitions + _last_event_type: str | None = None + + def _flush_current_reasoning(self) -> None: + """Flush accumulated reasoning to the list and add to sequence.""" + if self._current_reasoning.strip(): + self.reasoning_content.append(self._current_reasoning.strip()) + self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1}) + self._current_reasoning = "" + + def record_text_chunk(self, text: str) -> None: + """Record a text chunk event.""" + if not text: + return + + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + text_len = len(text) + start_pos = self._content_position + + # If last event was also content, extend it; otherwise create new + if self.sequence and self.sequence[-1].get("type") == "content": + self.sequence[-1]["end"] = start_pos + text_len + else: + self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len}) + + self._content_position += text_len + self._last_event_type = "content" + + def record_thought_chunk(self, text: str) -> None: + """Record a thought/reasoning chunk event.""" + if not text: + return + + # Accumulate thought content + self._current_reasoning += text + self._last_event_type = "thought" + + def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None: + """Record a tool call 
event.""" + if not tool_call_id: + return + + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + # Check if this tool call already exists (we might get multiple chunks) + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + # Update arguments if provided + if tool_arguments: + self.tool_calls[idx]["arguments"] = tool_arguments + else: + # New tool call + tool_call = { + "id": tool_call_id or "", + "name": tool_name or "", + "arguments": tool_arguments or "", + "result": "", + "elapsed_time": None, + } + self.tool_calls.append(tool_call) + idx = len(self.tool_calls) - 1 + self._tool_call_id_map[tool_call_id] = idx + self.sequence.append({"type": "tool_call", "index": idx}) + + self._last_event_type = "tool_call" + + def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None: + """Record a tool result event (update existing tool call).""" + if not tool_call_id: + return + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + self.tool_calls[idx]["result"] = result + self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time + + def finalize(self) -> None: + """Finalize the buffer, flushing any pending data.""" + if self._last_event_type == "thought": + self._flush_current_reasoning() + + def has_data(self) -> bool: + """Check if there's any meaningful data recorded.""" + return bool(self.reasoning_content or self.tool_calls or self.sequence) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -144,6 +255,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory self._graph_runtime_state: GraphRuntimeState | None = None + # Stream event buffer for recording generation sequence + self._stream_buffer = StreamEventBuffer() self._seed_graph_runtime_state_from_queue_manager() def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -358,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if node_finish_resp: yield node_finish_resp + # For ANSWER nodes, check if we need to send a message_replace event + # Only send if the final output differs from the accumulated task_state.answer + # This happens when variables were updated by variable_assigner during workflow execution + if event.node_type == NodeType.ANSWER and event.outputs: + final_answer = event.outputs.get("answer") + if final_answer is not None and final_answer != self._task_state.answer: + logger.info( + "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event", + final_answer, + self._task_state.answer, + ) + # Update the task state answer + self._task_state.answer = str(final_answer) + # Send message_replace event to update the UI + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=str(final_answer), + reason="variable_update", + ) + def _handle_node_failed_events( self, event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent], @@ -383,7 +515,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: - """Handle text 
chunk events.""" + """Handle text chunk events and record to stream buffer for sequence reconstruction.""" delta_text = event.text if delta_text is None: return @@ -405,9 +537,52 @@ if tts_publisher and queue_message: tts_publisher.publish(queue_message) - self._task_state.answer += delta_text + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else "" + tool_name = tool_payload.name if tool_payload and tool_payload.name else "" + tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else "" + tool_files = tool_result.files if tool_result else [] + tool_elapsed_time = tool_result.elapsed_time if tool_result else None + tool_icon = tool_payload.icon if tool_payload else None + tool_icon_dark = tool_payload.icon_dark if tool_payload else None + # Record stream event based on chunk type + chunk_type = event.chunk_type or ChunkType.TEXT + match chunk_type: + case ChunkType.TEXT: + self._stream_buffer.record_text_chunk(delta_text) + self._task_state.answer += delta_text + case ChunkType.THOUGHT: + # Reasoning should not be part of final answer text + self._stream_buffer.record_thought_chunk(delta_text) + case ChunkType.TOOL_CALL: + self._stream_buffer.record_tool_call( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + ) + case ChunkType.TOOL_RESULT: + self._stream_buffer.record_tool_result( + tool_call_id=tool_call_id, + result=delta_text, + tool_elapsed_time=tool_elapsed_time, + ) + self._task_state.answer += delta_text + case _: + pass yield self._message_cycle_manager.message_to_stream_response( - answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector + answer=delta_text, + message_id=self._message_id, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type.value if event.chunk_type else None, + tool_call_id=tool_call_id or None, + tool_name=tool_name or None, + tool_arguments=tool_arguments or None, + tool_files=tool_files, + tool_elapsed_time=tool_elapsed_time, + tool_icon=tool_icon, + tool_icon_dark=tool_icon_dark, ) def _handle_iteration_start_event( @@ -775,6 +950,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # If there are assistant files, remove markdown image links from answer answer_text = self._task_state.answer + answer_text = self._strip_think_blocks(answer_text) if self._recorded_files: # Remove markdown image links since we're storing files separately answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip() @@ -826,6 +1002,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) + # Save generation detail (reasoning/tool calls/sequence) from stream buffer + self._save_generation_detail(session=session, message=message) + + @staticmethod + def _strip_think_blocks(text: str) -> str: + """Remove <think>...</think> blocks (including their content) from text.""" + if not text or "<think" not in text: + return text + clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL) + clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() + return clean_text + + def _save_generation_detail(self, *, session: Session, message: Message) -> None: + """ + Save LLM generation detail for Chatflow using stream event buffer. + The buffer records the exact order of events as they streamed, + allowing accurate reconstruction of the generation sequence. 
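+ + Example (illustrative): a stored sequence of + [{"type": "reasoning", "index": 0}, {"type": "content", "start": 0, "end": 120}] + replays the first thought block, then the first 120 characters of the answer.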
+ """ + # Finalize the stream buffer to flush any pending data + self._stream_buffer.finalize() + + # Only save if there's meaningful data + if not self._stream_buffer.has_data(): + return + + reasoning_content = self._stream_buffer.reasoning_content + tool_calls = self._stream_buffer.tool_calls + sequence = self._stream_buffer.sequence + + # Check if generation detail already exists for this message + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None + existing.tool_calls = json.dumps(tool_calls) if tool_calls else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_content) if reasoning_content else None, + tool_calls=json.dumps(tool_calls) if tool_calls else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..f5cf7a2c56 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -3,10 +3,8 @@ from typing import cast from sqlalchemy import select -from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from core.agent.agent_app_runner import AgentAppRunner from core.agent.entities import AgentEntity -from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner @@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationError from extensions.ext_database import db @@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner): raise ValueError("Message not found") db.session.close() - runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] - # start agent runner - if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - # check LLM mode - if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: - runner_cls = CotChatAgentRunner - elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION: - runner_cls = CotCompletionAgentRunner - else: - raise ValueError(f"Invalid LLM mode: 
{model_schema.model_properties.get(ModelPropertyKey.MODE)}") - elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - runner_cls = FunctionCallAgentRunner - else: - raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}") - - runner = runner_cls( + runner = AgentAppRunner( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, conversation=conversation_result, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..0f3f9972c3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -671,7 +671,7 @@ class WorkflowResponseConverter: task_id=task_id, data=AgentLogStreamResponse.Data( node_execution_id=event.node_execution_id, - id=event.id, + message_id=event.id, parent_id=event.parent_id, label=event.label, error=event.error, diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 13eb40fd60..ea4441b5d8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) documents: list[Document] = [] - if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"): + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"): from services.dataset_service import DocumentService for datasource_info in datasource_info_list: @@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator): for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = args.get("original_document_id") or None - if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry: document_id = document_id or documents[i].id document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 842ad545ad..2b8ed38c63 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, + ChunkType, MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, @@ -483,11 +484,33 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): if delta_text is None: return + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None + tool_name = tool_payload.name if tool_payload and tool_payload.name else None + tool_arguments = tool_call.arguments if tool_call else None + tool_elapsed_time = tool_result.elapsed_time if tool_result else None + tool_files = tool_result.files if tool_result else [] + tool_icon = tool_payload.icon if tool_payload else None + tool_icon_dark = tool_payload.icon_dark if tool_payload else None + # only publish tts message at text chunk streaming if tts_publisher and queue_message: 
tts_publisher.publish(queue_message) - yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) + yield self._text_chunk_to_stream_response( + text=delta_text, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files, + tool_elapsed_time=tool_elapsed_time, + tool_icon=tool_icon, + tool_icon_dark=tool_icon_dark, + ) def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" @@ -650,16 +673,61 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): session.add(workflow_app_log) def _text_chunk_to_stream_response( - self, text: str, from_variable_selector: list[str] | None = None + self, + text: str, + from_variable_selector: list[str] | None = None, + chunk_type: ChunkType | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, + tool_elapsed_time: float | None = None, + tool_icon: str | dict | None = None, + tool_icon_dark: str | dict | None = None, ) -> TextChunkStreamResponse: """ Handle completed event. :param text: text :return: """ + from core.app.entities.task_entities import ChunkType as ResponseChunkType + + response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT + + data = TextChunkStreamResponse.Data( + text=text, + from_variable_selector=from_variable_selector, + chunk_type=response_chunk_type, + ) + + if response_chunk_type == ResponseChunkType.TOOL_CALL: + data = data.model_copy( + update={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_arguments": tool_arguments, + "tool_icon": tool_icon, + "tool_icon_dark": tool_icon_dark, + } + ) + elif response_chunk_type == ResponseChunkType.TOOL_RESULT: + data = data.model_copy( + update={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_arguments": tool_arguments, + "tool_files": tool_files, + "tool_error": tool_error, + "tool_elapsed_time": tool_elapsed_time, + "tool_icon": tool_icon, + "tool_icon_dark": tool_icon_dark, + } + ) + response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), + data=data, ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 0e125b3538..3b02683764 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -455,12 +455,20 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunStreamChunkEvent): + from core.app.entities.queue_entities import ChunkType as QueueChunkType + + if event.is_final and not event.chunk: + return + self._publish_event( QueueTextChunkEvent( text=event.chunk, from_variable_selector=list(event.selector), in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, + chunk_type=QueueChunkType(event.chunk_type.value), + tool_call=event.tool_call, + tool_result=event.tool_result, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cb573cb86..5bc453420d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ 
b/api/core/app/entities/app_invoke_entities.py @@ -42,7 +42,8 @@ class InvokeFrom(StrEnum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - PUBLISHED = "published" + # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. + PUBLISHED_PIPELINE = "published" # VALIDATION indicates that this invocation is from validation. VALIDATION = "validation" diff --git a/api/core/app/entities/llm_generation_entities.py b/api/core/app/entities/llm_generation_entities.py new file mode 100644 index 0000000000..33e97e3299 --- /dev/null +++ b/api/core/app/entities/llm_generation_entities.py @@ -0,0 +1,70 @@ +""" +LLM Generation Detail entities. + +Defines the structure for storing and transmitting LLM generation details +including reasoning content, tool calls, and their sequence. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class ContentSegment(BaseModel): + """Represents a content segment in the generation sequence.""" + + type: Literal["content"] = "content" + start: int = Field(..., description="Start position in the text") + end: int = Field(..., description="End position in the text") + + +class ReasoningSegment(BaseModel): + """Represents a reasoning segment in the generation sequence.""" + + type: Literal["reasoning"] = "reasoning" + index: int = Field(..., description="Index into reasoning_content array") + + +class ToolCallSegment(BaseModel): + """Represents a tool call segment in the generation sequence.""" + + type: Literal["tool_call"] = "tool_call" + index: int = Field(..., description="Index into tool_calls array") + + +SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment + + +class ToolCallDetail(BaseModel): + """Represents a tool call with its arguments and result.""" + + id: str = Field(default="", description="Unique identifier for the tool call") + name: str = Field(..., description="Name of the tool") + arguments: str = Field(default="", description="JSON string of tool arguments") + result: str = Field(default="", description="Result from the tool execution") + elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds") + + +class LLMGenerationDetailData(BaseModel): + """ + Domain model for LLM generation detail. + + Contains the structured data for reasoning content, tool calls, + and their display sequence. 
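+ + Example (illustrative): reasoning_content=["I should call a tool"], + tool_calls=[ToolCallDetail(name="web_search")] and + sequence=[ReasoningSegment(index=0), ToolCallSegment(index=0)] replay a + thought followed by its tool call, in stream order.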
+ """ + + reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments") + tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details") + sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments") + + def is_empty(self) -> bool: + """Check if there's any meaningful generation detail.""" + return not self.reasoning_content and not self.tool_calls + + def to_response_dict(self) -> dict: + """Convert to dictionary for API response.""" + return { + "reasoning_content": self.reasoning_content, + "tool_calls": [tc.model_dump() for tc in self.tool_calls], + "sequence": [seg.model_dump() for seg in self.sequence], + } diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 77d6bf03b4..fdc4014caa 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -177,6 +177,17 @@ class QueueLoopCompletedEvent(AppQueueEvent): error: str | None = None +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + THOUGHT_START = "thought_start" # Agent thought start + THOUGHT_END = "thought_end" # Agent thought end + + class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity @@ -191,6 +202,16 @@ class QueueTextChunkEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = ChunkType.TEXT + """type of the chunk""" + + # Tool streaming payloads + tool_call: ToolCall | None = None + """structured tool call info""" + tool_result: ToolResult | None = None + """structured tool result info""" + class QueueAgentMessageEvent(AppQueueEvent): """ diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..0998510b60 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -113,6 +113,38 @@ class MessageStreamResponse(StreamResponse): answer: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import) + chunk_type: str | None = None + """type of the chunk: text, tool_call, tool_result, thought""" + + # Tool call fields (when chunk_type == "tool_call") + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == "tool_result") + tool_files: list[str] | None = None + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + tool_elapsed_time: float | None = None + """elapsed time spent executing the tool""" 
+ tool_icon: str | dict | None = None + """icon of the tool""" + tool_icon_dark: str | dict | None = None + """dark theme icon of the tool""" + + def model_dump(self, *args, **kwargs) -> dict[str, object]: + kwargs.setdefault("exclude_none", True) + return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs) -> str: + kwargs.setdefault("exclude_none", True) + return super().model_dump_json(*args, **kwargs) + class MessageAudioStreamResponse(StreamResponse): """ @@ -582,6 +614,17 @@ class LoopNodeCompletedStreamResponse(StreamResponse): data: Data +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + THOUGHT_START = "thought_start" # Agent thought start + THOUGHT_END = "thought_end" # Agent thought end + + class TextChunkStreamResponse(StreamResponse): """ TextChunkStreamResponse entity @@ -595,6 +638,36 @@ class TextChunkStreamResponse(StreamResponse): text: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = ChunkType.TEXT + """type of the chunk""" + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_files: list[str] | None = None + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + + # Tool elapsed time fields (when chunk_type == TOOL_RESULT) + tool_elapsed_time: float | None = None + """elapsed time spent executing the tool""" + + def model_dump(self, *args, **kwargs) -> dict[str, object]: + kwargs.setdefault("exclude_none", True) + return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs) -> str: + kwargs.setdefault("exclude_none", True) + return super().model_dump_json(*args, **kwargs) + event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -743,7 +816,7 @@ class AgentLogStreamResponse(StreamResponse): """ node_execution_id: str - id: str + message_id: str label: str parent_id: str | None = None error: str | None = None diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py new file mode 100644 index 0000000000..77cc00bdc9 --- /dev/null +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -0,0 +1,60 @@ +import logging + +from core.variables import Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers + +logger = logging.getLogger(__name__) + + +class ConversationVariablePersistenceLayer(GraphEngineLayer): + def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None: + super().__init__() + self._conversation_variable_updater = conversation_variable_updater + + def on_graph_start(self) -> None: 
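+ """No-op: this layer only reacts to successful variable-assigner nodes."""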
+ pass + + def on_event(self, event: GraphEngineEvent) -> None: + if not isinstance(event, NodeRunSucceededEvent): + return + if event.node_type != NodeType.VARIABLE_ASSIGNER: + return + if self.graph_runtime_state is None: + return + + updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] + if not updated_variables: + return + + conversation_id = self.graph_runtime_state.system_variable.conversation_id + if conversation_id is None: + return + + updated_any = False + for item in updated_variables: + selector = item.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) + continue + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + variable = self.graph_runtime_state.variable_pool.get(selector) + if not isinstance(variable, Variable): + logger.warning( + "Conversation variable not found in variable pool. selector=%s", + selector, + ) + continue + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) + updated_any = True + + if updated_any: + self._conversation_variable_updater.flush() + + def on_graph_end(self, error: Exception | None) -> None: + pass diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..a1852ffe19 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,5 @@ import logging +import re import time from collections.abc import Generator from threading import Thread @@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.model import AppMode, Conversation, Message, MessageAgentThought +from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL) + _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] @@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) ) + # Save LLM generation detail if there's reasoning_content + self._save_generation_detail(session=session, message=message, llm_result=llm_result) + message_was_created.send( message, application_generate_entity=self._application_generate_entity, ) + def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None: + """ + Save LLM generation detail for Completion/Chat/Agent-Chat applications. + For Agent-Chat, also merges MessageAgentThought records. 
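+ Unlike the chatflow pipeline there is no stream buffer here, so the + sequence is reconstructed after the fact from stored agent thoughts and + any <think> markup in the answer.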
+ """ + import json + + reasoning_list: list[str] = [] + tool_calls_list: list[dict] = [] + sequence: list[dict] = [] + answer = message.answer or "" + + # Check if this is Agent-Chat mode by looking for agent thoughts + agent_thoughts = ( + session.query(MessageAgentThought) + .filter_by(message_id=message.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) + + if agent_thoughts: + # Agent-Chat mode: merge MessageAgentThought records + content_pos = 0 + cleaned_answer_parts: list[str] = [] + for thought in agent_thoughts: + # Add thought/reasoning + if thought.thought: + reasoning_text = thought.thought + if " blocks and clean the final answer + clean_answer, reasoning_content = self._split_reasoning_from_answer(answer) + if reasoning_content: + answer = clean_answer + llm_result.message.content = clean_answer + llm_result.reasoning_content = reasoning_content + message.answer = clean_answer + if reasoning_content: + reasoning_list = [reasoning_content] + # Content comes first, then reasoning + if answer: + sequence.append({"type": "content", "start": 0, "end": len(answer)}) + sequence.append({"type": "reasoning", "index": 0}) + + # Only save if there's meaningful generation detail + if not reasoning_list and not tool_calls_list: + return + + # Check if generation detail already exists + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None + existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_list) if reasoning_list else None, + tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + + @classmethod + def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]: + """ + Extract reasoning segments from blocks and return (clean_text, reasoning). + """ + matches = cls._THINK_PATTERN.findall(text) + reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" + + clean_text = cls._THINK_PATTERN.sub("", text) + clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() + + return clean_text, reasoning_content or "" + def _handle_stop(self, event: QueueStopEvent): """ Handle stop. diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0e7f300cee..f9f341fcea 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -232,15 +232,31 @@ class MessageCycleManager: answer: str, message_id: str, from_variable_selector: list[str] | None = None, + chunk_type: str | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, + tool_elapsed_time: float | None = None, + tool_icon: str | dict | None = None, + tool_icon_dark: str | dict | None = None, event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. 
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index 0e7f300cee..f9f341fcea 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -232,15 +232,31 @@ class MessageCycleManager:
         answer: str,
         message_id: str,
         from_variable_selector: list[str] | None = None,
+        chunk_type: str | None = None,
+        tool_call_id: str | None = None,
+        tool_name: str | None = None,
+        tool_arguments: str | None = None,
+        tool_files: list[str] | None = None,
+        tool_error: str | None = None,
+        tool_elapsed_time: float | None = None,
+        tool_icon: str | dict | None = None,
+        tool_icon_dark: str | dict | None = None,
         event_type: StreamEvent | None = None,
     ) -> MessageStreamResponse:
         """
         Message to stream response.
         :param answer: answer
         :param message_id: message id
+        :param from_variable_selector: from variable selector
+        :param chunk_type: type of the chunk (text, function_call, tool_result, thought)
+        :param tool_call_id: unique identifier for this tool call
+        :param tool_name: name of the tool being called
+        :param tool_arguments: accumulated tool arguments JSON
+        :param tool_files: file IDs produced by the tool
+        :param tool_error: error message if the tool failed
         :return:
         """
-        return MessageStreamResponse(
+        response = MessageStreamResponse(
             task_id=self._application_generate_entity.task_id,
             id=message_id,
             answer=answer,
@@ -248,6 +264,35 @@ class MessageCycleManager:
             event=event_type or StreamEvent.MESSAGE,
         )
 
+        if chunk_type:
+            response = response.model_copy(update={"chunk_type": chunk_type})
+
+        if chunk_type == "tool_call":
+            response = response.model_copy(
+                update={
+                    "tool_call_id": tool_call_id,
+                    "tool_name": tool_name,
+                    "tool_arguments": tool_arguments,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
+                }
+            )
+        elif chunk_type == "tool_result":
+            response = response.model_copy(
+                update={
+                    "tool_call_id": tool_call_id,
+                    "tool_name": tool_name,
+                    "tool_arguments": tool_arguments,
+                    "tool_files": tool_files,
+                    "tool_error": tool_error,
+                    "tool_elapsed_time": tool_elapsed_time,
+                    "tool_icon": tool_icon,
+                    "tool_icon_dark": tool_icon_dark,
+                }
+            )
+
+        return response
+
     def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
         """
         Message replace to stream response.
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index d0279349ca..5249fea8cd 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -5,7 +5,6 @@ from sqlalchemy import select
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.models.document import Document
@@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
     # TODO(-LAN-): Improve type check
     def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
         """Handle return_retriever_resource_info."""
+        from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
+
         self._queue_manager.publish(
             QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py
index 50c7249fe4..451e4fda0e 100644
--- a/api/core/datasource/__base/datasource_plugin.py
+++ b/api/core/datasource/__base/datasource_plugin.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 
 from configs import dify_config
@@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
         """
         return DatasourceProviderType.LOCAL_FILE
 
-    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
         return self.__class__(
             entity=self.entity.model_copy(),
             runtime=runtime,
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 260dcf04f5..dde7d59726 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 from enum import StrEnum
 from typing import Any
@@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
     ONLINE_DRIVE = "online_drive"
 
     @classmethod
-    def value_of(cls, value: str) -> "DatasourceProviderType":
+    def value_of(cls, value: str) -> DatasourceProviderType:
         """
         Get value of given mode.
 
@@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
         typ: DatasourceParameterType,
         required: bool,
         options: list[str] | None = None,
-    ) -> "DatasourceParameter":
+    ) -> DatasourceParameter:
         """
         get a simple datasource parameter
 
@@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
     tool_config: dict | None = None
 
     @classmethod
-    def empty(cls) -> "DatasourceInvokeMeta":
+    def empty(cls) -> DatasourceInvokeMeta:
         """
         Get an empty instance of DatasourceInvokeMeta
         """
         return cls(time_cost=0.0, error=None, tool_config={})
 
     @classmethod
-    def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
+    def error_instance(cls, error: str) -> DatasourceInvokeMeta:
         """
         Get an instance of DatasourceInvokeMeta with error
         """
diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py
index 1dae2eafd4..45d4bc4594 100644
--- a/api/core/db/session_factory.py
+++ b/api/core/db/session_factory.py
@@ -1,7 +1,7 @@
 from sqlalchemy import Engine
 from sqlalchemy.orm import Session, sessionmaker
 
-_session_maker: sessionmaker | None = None
+_session_maker: sessionmaker[Session] | None = None
 
 
 def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@@ -10,7 +10,7 @@
     _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
 
 
-def get_session_maker() -> sessionmaker:
+def get_session_maker() -> sessionmaker[Session]:
     if _session_maker is None:
         raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
     return _session_maker
@@ -27,7 +27,7 @@ class SessionFactory:
         configure_session_factory(engine, expire_on_commit)
 
     @staticmethod
-    def get_session_maker() -> sessionmaker:
+    def get_session_maker() -> sessionmaker[Session]:
         return get_session_maker()
 
     @staticmethod
diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py
index 7fdf5e4be6..135d2a4945 100644
--- a/api/core/entities/mcp_provider.py
+++ b/api/core/entities/mcp_provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from datetime import datetime
 from enum import StrEnum
@@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
     updated_at: datetime
 
     @classmethod
-    def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
+    def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
         """Create entity from database model with decryption"""
         return cls(
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index 8a8067332d..0078ec7e4f 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import StrEnum, auto
 from typing import Union
 
@@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
         TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
 
     @classmethod
-    def value_of(cls, value: str) -> "ProviderConfig.Type":
+    def value_of(cls, value: str) -> ProviderConfig.Type:
         """
         Get value of given mode.
diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py
index 6d553d7dc6..2ac483673a 100644
--- a/api/core/file/helpers.py
+++ b/api/core/file/helpers.py
@@ -8,8 +8,9 @@ import urllib.parse
 from configs import dify_config
 
 
-def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
-    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
+def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
+    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
+    url = f"{base_url}/files/{upload_file_id}/file-preview"
 
     timestamp = str(int(time.time()))
     nonce = os.urandom(16).hex()
diff --git a/api/core/file/models.py b/api/core/file/models.py
index d149205d77..6324523b22 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -112,17 +112,17 @@ class File(BaseModel):
 
         return text
 
-    def generate_url(self) -> str | None:
+    def generate_url(self, for_external: bool = True) -> str | None:
         if self.transfer_method == FileTransferMethod.REMOTE_URL:
             return self.remote_url
         elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
             if self.related_id is None:
                 raise ValueError("Missing file related_id")
-            return helpers.get_signed_file_url(upload_file_id=self.related_id)
+            return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
         elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
             assert self.related_id is not None
             assert self.extension is not None
-            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
+            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
         return None
 
     def to_plugin_parameter(self) -> dict[str, Any]:
@@ -133,7 +133,7 @@ class File(BaseModel):
             "extension": self.extension,
             "size": self.size,
             "type": self.type,
-            "url": self.generate_url(),
+            "url": self.generate_url(for_external=False),
         }
 
     @model_validator(mode="after")
@model_validator(mode="after") diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 6fda073913..5cdea19a8d 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -76,7 +76,7 @@ class TemplateTransformer(ABC): Post-process the result to convert scientific notation strings back to numbers """ - def convert_scientific_notation(value): + def convert_scientific_notation(value: Any) -> Any: if isinstance(value, str): # Check if the string looks like scientific notation if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): @@ -90,7 +90,7 @@ class TemplateTransformer(ABC): return [convert_scientific_notation(v) for v in value] return value - return convert_scientific_notation(result) # type: ignore[no-any-return] + return convert_scientific_notation(result) @classmethod @abstractmethod diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index c97ae6eac7..84a6fd0d1f 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): request_id: RequestId, request_meta: RequestParams.Meta | None, request: ReceiveRequestT, - session: """BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT - ]""", + session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ): self.request_id = request_id diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 89dae2dbff..3ac83b4c96 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from collections.abc import Mapping, Sequence from enum import StrEnum, auto @@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum): TOOL = auto() @classmethod - def value_of(cls, value: str) -> "PromptMessageRole": + def value_of(cls, value: str) -> PromptMessageRole: """ Get value of given mode. diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index aee6ce1108..19194d162c 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from decimal import Decimal from enum import StrEnum, auto from typing import Any @@ -20,7 +22,7 @@ class ModelType(StrEnum): TTS = auto() @classmethod - def value_of(cls, origin_model_type: str) -> "ModelType": + def value_of(cls, origin_model_type: str) -> ModelType: """ Get model type from origin model type. @@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum): JSON_SCHEMA = auto() @classmethod - def value_of(cls, value: Any) -> "DefaultParameterName": + def value_of(cls, value: Any) -> DefaultParameterName: """ Get parameter name from value. 
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index aee6ce1108..19194d162c 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from decimal import Decimal
 from enum import StrEnum, auto
 from typing import Any
@@ -20,7 +22,7 @@ class ModelType(StrEnum):
     TTS = auto()
 
     @classmethod
-    def value_of(cls, origin_model_type: str) -> "ModelType":
+    def value_of(cls, origin_model_type: str) -> ModelType:
         """
         Get model type from origin model type.
 
@@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
     JSON_SCHEMA = auto()
 
     @classmethod
-    def value_of(cls, value: Any) -> "DefaultParameterName":
+    def value_of(cls, value: Any) -> DefaultParameterName:
         """
         Get parameter name from value.
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index 12a202ce64..28f162a928 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import hashlib
 import logging
 from collections.abc import Sequence
@@ -38,7 +40,7 @@ class ModelProviderFactory:
         plugin_providers = self.get_plugin_model_providers()
         return [provider.declaration for provider in plugin_providers]
 
-    def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
+    def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
         """
         Get all plugin model providers
         :return: list of plugin model providers
@@ -76,7 +78,7 @@ class ModelProviderFactory:
         plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
         return plugin_model_provider_entity.declaration
 
-    def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
+    def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
         """
         Get plugin model provider
         :param provider: provider name
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 3b83121357..6674228dc0 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 from collections.abc import Mapping, Sequence
 from datetime import datetime
@@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
         return [item.value for item in cls]
 
     @classmethod
-    def of(cls, credential_type: str) -> "CredentialType":
+    def of(cls, credential_type: str) -> CredentialType:
         type_name = credential_type.lower()
         if type_name in {"api-key", "api_key"}:
             return cls.API_KEY
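`CredentialType.of` normalizes both the `"api-key"` and `"api_key"` spellings. A minimal sketch of that normalization (the `OAUTH2` member is an assumption for illustration):

```python
from enum import StrEnum

class CredentialType(StrEnum):
    API_KEY = "api-key"
    OAUTH2 = "oauth2"  # illustrative member, not taken from the diff

    @classmethod
    def of(cls, credential_type: str) -> "CredentialType":
        # Accept both hyphen and underscore spellings for API keys.
        name = credential_type.lower()
        if name in {"api-key", "api_key"}:
            return cls.API_KEY
        return cls(name)

assert CredentialType.of("API_KEY") is CredentialType.API_KEY
assert CredentialType.of("oauth2") is CredentialType.OAUTH2
```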
""" - _instance: Optional["ClickzettaConnectionPool"] = None + _instance: ClickzettaConnectionPool | None = None _lock = threading.Lock() def __init__(self): @@ -89,7 +91,7 @@ class ClickzettaConnectionPool: self._start_cleanup_thread() @classmethod - def get_instance(cls) -> "ClickzettaConnectionPool": + def get_instance(cls) -> ClickzettaConnectionPool: """Get singleton instance of connection pool.""" if cls._instance is None: with cls._lock: @@ -104,7 +106,7 @@ class ClickzettaConnectionPool: f"{config.workspace}:{config.vcluster}:{config.schema_name}" ) - def _create_connection(self, config: ClickzettaConfig) -> "Connection": + def _create_connection(self, config: ClickzettaConfig) -> Connection: """Create a new ClickZetta connection.""" max_retries = 3 retry_delay = 1.0 @@ -134,7 +136,7 @@ class ClickzettaConnectionPool: raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") - def _configure_connection(self, connection: "Connection"): + def _configure_connection(self, connection: Connection): """Configure connection session settings.""" try: with connection.cursor() as cursor: @@ -181,7 +183,7 @@ class ClickzettaConnectionPool: except Exception: logger.exception("Failed to configure connection, continuing with defaults") - def _is_connection_valid(self, connection: "Connection") -> bool: + def _is_connection_valid(self, connection: Connection) -> bool: """Check if connection is still valid.""" try: with connection.cursor() as cursor: @@ -190,7 +192,7 @@ class ClickzettaConnectionPool: except Exception: return False - def get_connection(self, config: ClickzettaConfig) -> "Connection": + def get_connection(self, config: ClickzettaConfig) -> Connection: """Get a connection from the pool or create a new one.""" config_key = self._get_config_key(config) @@ -221,7 +223,7 @@ class ClickzettaConnectionPool: # No valid connection found, create new one return self._create_connection(config) - def return_connection(self, config: ClickzettaConfig, connection: "Connection"): + def return_connection(self, config: ClickzettaConfig, connection: Connection): """Return a connection to the pool.""" config_key = self._get_config_key(config) @@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector): self._connection_pool = ClickzettaConnectionPool.get_instance() self._init_write_queue() - def _get_connection(self) -> "Connection": + def _get_connection(self) -> Connection: """Get a connection from the pool.""" return self._connection_pool.get_connection(self._config) - def _return_connection(self, connection: "Connection"): + def _return_connection(self, connection: Connection): """Return a connection to the pool.""" self._connection_pool.return_connection(self._config, connection) class ConnectionContext: """Context manager for borrowing and returning connections.""" - def __init__(self, vector_instance: "ClickzettaVector"): + def __init__(self, vector_instance: ClickzettaVector): self.vector = vector_instance self.connection: Connection | None = None - def __enter__(self) -> "Connection": + def __enter__(self) -> Connection: self.connection = self.vector._get_connection() return self.connection @@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector): if self.connection: self.vector._return_connection(self.connection) - def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + def get_connection_context(self) -> ClickzettaVector.ConnectionContext: """Get a connection context manager.""" return self.ConnectionContext(self) @@ -437,7 +439,7 @@ class 
@@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
         """Return the vector database type."""
         return "clickzetta"
 
-    def _ensure_connection(self) -> "Connection":
+    def _ensure_connection(self) -> Connection:
         """Get a connection from the pool."""
         return self._get_connection()
 
@@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):
 
         # No need for dataset_id filter since each dataset has its own table
 
-        # Use simple quote escaping for LIKE clause
-        escaped_query = query.replace("'", "''")
-        filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
+        # Escape special characters for LIKE clause to prevent SQL injection
+        from libs.helper import escape_like_pattern
+
+        escaped_query = escape_like_pattern(query).replace("'", "''")
+        filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
 
         where_clause = " AND ".join(filter_clauses)
         search_sql = f"""
diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py
index b1bfabb76e..5bdb0af0b3 100644
--- a/api/core/rag/datasource/vdb/iris/iris_vector.py
+++ b/api/core/rag/datasource/vdb/iris/iris_vector.py
@@ -287,11 +287,15 @@ class IrisVector(BaseVector):
                 cursor.execute(sql, (query,))
             else:
                 # Fallback to LIKE search (inefficient for large datasets)
-                query_pattern = f"%{query}%"
+                # Escape special characters for LIKE clause to prevent SQL injection
+                from libs.helper import escape_like_pattern
+
+                escaped_query = escape_like_pattern(query)
+                query_pattern = f"%{escaped_query}%"
                 sql = f"""
                     SELECT TOP {top_k} id, text, meta
                     FROM {self.schema}.{self.table_name}
-                    WHERE text LIKE ?
+                    WHERE text LIKE ? ESCAPE '\\'
                 """
                 cursor.execute(sql, (query_pattern,))
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 1fe74d3042..69adac522d 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
 from typing import Any
 
@@ -22,7 +24,7 @@ class DatasetDocumentStore:
         self._document_id = document_id
 
     @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
+    def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
         return cls(**config_dict)
 
     def to_dict(self) -> dict[str, Any]:
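Both fixes route user input through `escape_like_pattern` before embedding it in a `LIKE` clause. A sketch of what such a helper typically does (the real implementation lives in `libs.helper`; this is an assumption about its behavior, not a copy):

```python
def escape_like_pattern(value: str) -> str:
    # Escape the backslash first, then the LIKE wildcards, so that '%' and '_'
    # in user input match literally instead of acting as wildcards.
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

print(escape_like_pattern("50%_off"))  # 50\%\_off
```

Paired with `ESCAPE '\'` in the SQL, this keeps user-supplied wildcards inert.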
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index f67f613e9d..511f5a698d 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -7,10 +7,11 @@ import re
 import tempfile
 import uuid
 from urllib.parse import urlparse
-from xml.etree import ElementTree
 
 import httpx
 from docx import Document as DocxDocument
+from docx.oxml.ns import qn
+from docx.text.run import Run
 
 from configs import dify_config
 from core.helper import ssrf_proxy
@@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):
 
         image_map = self._extract_images_from_docx(doc)
 
-        hyperlinks_url = None
-        url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
-        for para in doc.paragraphs:
-            for run in para.runs:
-                if run.text and hyperlinks_url:
-                    result = f" [{run.text}]({hyperlinks_url}) "
-                    run.text = result
-                    hyperlinks_url = None
-                if "HYPERLINK" in run.element.xml:
-                    try:
-                        xml = ElementTree.XML(run.element.xml)
-                        x_child = [c for c in xml.iter() if c is not None]
-                        for x in x_child:
-                            if x is None:
-                                continue
-                            if x.tag.endswith("instrText"):
-                                if x.text is None:
-                                    continue
-                                for i in url_pattern.findall(x.text):
-                                    hyperlinks_url = str(i)
-                    except Exception:
-                        logger.exception("Failed to parse HYPERLINK xml")
-
         def parse_paragraph(paragraph):
-            paragraph_content = []
-
-            def append_image_link(image_id, has_drawing):
+            def append_image_link(image_id, has_drawing, target_buffer):
                 """Helper to append image link from image_map based on relationship type."""
                 rel = doc.part.rels[image_id]
                 if rel.is_external:
                     if image_id in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_id])
+                        target_buffer.append(image_map[image_id])
                 else:
                     image_part = rel.target_part
                     if image_part in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_part])
+                        target_buffer.append(image_map[image_part])
 
-            for run in paragraph.runs:
+            def process_run(run, target_buffer):
+                # Helper to extract text and embedded images from a run element and append them to target_buffer
                 if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
                     # Process drawing type images
                     drawing_elements = run.element.findall(
@@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
                             # External image: use embed_id as key
                             if embed_id in image_map:
                                 has_drawing = True
-                                paragraph_content.append(image_map[embed_id])
+                                target_buffer.append(image_map[embed_id])
                         else:
                             # Internal image: use target_part as key
                             image_part = doc.part.related_parts.get(embed_id)
                             if image_part in image_map:
                                 has_drawing = True
-                                paragraph_content.append(image_map[image_part])
+                                target_buffer.append(image_map[image_part])
                 # Process pict type images
                 shape_elements = run.element.findall(
                     ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
                         "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                     )
                     if image_id and image_id in doc.part.rels:
-                        append_image_link(image_id, has_drawing)
+                        append_image_link(image_id, has_drawing, target_buffer)
                 # Find imagedata element in VML
                 image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
                 if image_data is not None:
@@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
                         "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                     )
                     if image_id and image_id in doc.part.rels:
-                        append_image_link(image_id, has_drawing)
+                        append_image_link(image_id, has_drawing, target_buffer)
                 if run.text.strip():
-                    paragraph_content.append(run.text.strip())
+                    target_buffer.append(run.text.strip())
+
+            def process_hyperlink(hyperlink_elem, target_buffer):
+                # Helper to extract text from a hyperlink element and append it to target_buffer
+                r_id = hyperlink_elem.get(qn("r:id"))
+
+                # Extract text from runs inside the hyperlink
+                link_text_parts = []
+                for run_elem in hyperlink_elem.findall(qn("w:r")):
+                    run = Run(run_elem, paragraph)
+                    # Hyperlink text may be split across multiple runs (e.g., with different formatting),
+                    # so collect all run texts first
+                    if run.text:
+                        link_text_parts.append(run.text)
+
+                link_text = "".join(link_text_parts).strip()
+
+                # Resolve URL
+                if r_id:
+                    try:
+                        rel = doc.part.rels.get(r_id)
+                        if rel and rel.is_external:
+                            link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
+                    except Exception:
+                        logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
+
+                if link_text:
+                    target_buffer.append(link_text)
+
+            paragraph_content = []
+            # State for legacy HYPERLINK fields
+            hyperlink_field_url = None
+            hyperlink_field_text_parts: list = []
+            is_collecting_field_text = False
+            # Iterate through paragraph elements in document order
+            for child in paragraph._element:
+                tag = child.tag
+                if tag == qn("w:r"):
+                    # Regular run
+                    run = Run(child, paragraph)
+
+                    # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
+                    fld_chars = child.findall(qn("w:fldChar"))
+                    instr_texts = child.findall(qn("w:instrText"))
+
+                    # Handle Fields
+                    if fld_chars or instr_texts:
+                        # Process instrText to find HYPERLINK "url"
+                        for instr in instr_texts:
+                            if instr.text and "HYPERLINK" in instr.text:
+                                # Quick regex to extract URL
+                                match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
+                                if match:
+                                    hyperlink_field_url = match.group(1)
+
+                        # Process fldChar
+                        for fld_char in fld_chars:
+                            fld_char_type = fld_char.get(qn("w:fldCharType"))
+                            if fld_char_type == "begin":
+                                # Start of a field: reset legacy link state
+                                hyperlink_field_url = None
+                                hyperlink_field_text_parts = []
+                                is_collecting_field_text = False
+                            elif fld_char_type == "separate":
+                                # Separator: if we found a URL, start collecting visible text
+                                if hyperlink_field_url:
+                                    is_collecting_field_text = True
+                            elif fld_char_type == "end":
+                                # End of field
+                                if is_collecting_field_text and hyperlink_field_url:
+                                    # Create markdown link and append to main content
+                                    display_text = "".join(hyperlink_field_text_parts).strip()
+                                    if display_text:
+                                        link_md = f"[{display_text}]({hyperlink_field_url})"
+                                        paragraph_content.append(link_md)
+                                # Reset state
+                                hyperlink_field_url = None
+                                hyperlink_field_text_parts = []
+                                is_collecting_field_text = False
+
+                    # Decide where to append content
+                    target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
+                    process_run(run, target_buffer)
+                elif tag == qn("w:hyperlink"):
+                    process_hyperlink(child, paragraph_content)
 
             return "".join(paragraph_content) if paragraph_content else ""
 
         paragraphs = doc.paragraphs.copy()
diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py
index 7472598a7f..bf8db95b4e 100644
--- a/api/core/rag/pipeline/queue.py
+++ b/api/core/rag/pipeline/queue.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from collections.abc import Sequence
 from typing import Any
@@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
         return self.model_dump_json()
 
     @classmethod
-    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
+    def deserialize(cls, serialized_data: str) -> TaskWrapper:
         return cls.model_validate_json(serialized_data)
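Legacy Word hyperlinks are stored as field codes: a `begin` fldChar, an `instrText` carrying `HYPERLINK "url"`, a `separate` marker, the visible runs, then `end`. A stripped-down sketch of that state machine over a flat token stream (token shapes are illustrative, not python-docx types):

```python
import re

def render_hyperlink_fields(tokens: list[tuple[str, str]]) -> str:
    # tokens are (kind, value): kind in {"begin", "instr", "separate", "end", "text"}
    out, buf, url, collecting = [], [], None, False
    for kind, value in tokens:
        if kind == "begin":
            url, buf, collecting = None, [], False
        elif kind == "instr":
            m = re.search(r'HYPERLINK\s+"([^"]+)"', value, re.IGNORECASE)
            url = m.group(1) if m else url
        elif kind == "separate":
            collecting = url is not None
        elif kind == "end":
            if collecting and url:
                out.append(f"[{''.join(buf)}]({url})")
            url, buf, collecting = None, [], False
        elif kind == "text":
            (buf if collecting else out).append(value)
    return "".join(out)

print(render_hyperlink_fields([
    ("begin", ""), ("instr", 'HYPERLINK "https://example.com"'),
    ("separate", ""), ("text", "Example"), ("end", ""),
]))  # [Example](https://example.com)
```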
filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4436773d25..a45d1d1046 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -29,6 +29,7 @@ from models import ( Account, CreatorUserRole, EndUser, + LLMGenerationDetail, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, ) @@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) session.merge(db_model) session.flush() + # Save LLMGenerationDetail for LLM nodes with successful execution + if ( + domain_model.node_type == NodeType.LLM + and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED + and domain_model.outputs is not None + ): + self._save_llm_generation_detail(session, domain_model) + + def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None: + """ + Save LLM generation detail for LLM nodes. + Extracts reasoning_content, tool_calls, and sequence from outputs and metadata. + """ + outputs = execution.outputs or {} + metadata = execution.metadata or {} + + reasoning_list = self._extract_reasoning(outputs) + tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)) + + if not reasoning_list and not tool_calls_list: + return + + sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list) + self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence) + + def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]: + """Extract reasoning_content as a clean list of non-empty strings.""" + reasoning_content = outputs.get("reasoning_content") + if isinstance(reasoning_content, str): + trimmed = reasoning_content.strip() + return [trimmed] if trimmed else [] + if isinstance(reasoning_content, list): + return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()] + return [] + + def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]: + """Extract tool call records from agent logs.""" + if not agent_log or not isinstance(agent_log, list): + return [] + + tool_calls: list[dict[str, str]] = [] + for log in agent_log: + log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {}) + tool_name = log_data.get("tool_name") + if tool_name and str(tool_name).strip(): + tool_calls.append( + { + "id": log_data.get("tool_call_id", ""), + "name": tool_name, + "arguments": json.dumps(log_data.get("tool_args", {})), + "result": str(log_data.get("output", "")), + } + ) + return tool_calls + + def _build_generation_sequence( + self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]] + ) -> list[dict[str, Any]]: + """Build a simple content/reasoning/tool_call sequence.""" + sequence: list[dict[str, Any]] = [] + if text: + sequence.append({"type": "content", "start": 0, "end": len(text)}) + for index in range(len(reasoning_list)): + sequence.append({"type": "reasoning", "index": index}) + for index in range(len(tool_calls_list)): + sequence.append({"type": "tool_call", "index": index}) + return sequence + + def _upsert_generation_detail( + self, + session, + execution: WorkflowNodeExecution, + 
+    def _upsert_generation_detail(
+        self,
+        session,
+        execution: WorkflowNodeExecution,
+        reasoning_list: list[str],
+        tool_calls_list: list[dict[str, str]],
+        sequence: list[dict[str, Any]],
+    ) -> None:
+        """Insert or update LLMGenerationDetail with serialized fields."""
+        existing = (
+            session.query(LLMGenerationDetail)
+            .filter_by(
+                workflow_run_id=execution.workflow_execution_id,
+                node_id=execution.node_id,
+            )
+            .first()
+        )
+
+        reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
+        tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
+        sequence_json = json.dumps(sequence) if sequence else None
+
+        if existing:
+            existing.reasoning_content = reasoning_json
+            existing.tool_calls = tool_calls_json
+            existing.sequence = sequence_json
+            return
+
+        generation_detail = LLMGenerationDetail(
+            tenant_id=self._tenant_id,
+            app_id=self._app_id,
+            workflow_run_id=execution.workflow_execution_id,
+            node_id=execution.node_id,
+            reasoning_content=reasoning_json,
+            tool_calls=tool_calls_json,
+            sequence=sequence_json,
+        )
+        session.add(generation_detail)
+
     def get_db_models_by_workflow_run(
         self,
         workflow_run_id: str,
diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py
index 51bfae1cd3..b4ecfe47ff 100644
--- a/api/core/schemas/registry.py
+++ b/api/core/schemas/registry.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import json
 import logging
 import threading
 from collections.abc import Mapping, MutableMapping
 from pathlib import Path
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar
 
 
 class SchemaRegistry:
@@ -11,7 +13,7 @@ class SchemaRegistry:
 
     logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
 
-    _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
+    _default_instance: ClassVar[SchemaRegistry | None] = None
     _lock: ClassVar[threading.Lock] = threading.Lock()
 
     def __init__(self, base_dir: str):
@@ -20,7 +22,7 @@ class SchemaRegistry:
         self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
 
     @classmethod
-    def default_registry(cls) -> "SchemaRegistry":
+    def default_registry(cls) -> SchemaRegistry:
         """Returns the default schema registry for builtin schemas (thread-safe singleton)"""
         if cls._default_instance is None:
             with cls._lock:
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 8ca4eabb7a..24fc11aefc 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections.abc import Generator
 from copy import deepcopy
@@ -6,6 +8,7 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from models.model import File
 
+from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import (
     ToolEntity,
@@ -24,7 +27,7 @@ class Tool(ABC):
         self.entity = entity
         self.runtime = runtime
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
         """
         fork a new tool with metadata
         :return: the new tool
@@ -152,6 +155,60 @@ class Tool(ABC):
 
         return parameters
 
+    def to_prompt_message_tool(self) -> PromptMessageTool:
+        message_tool = PromptMessageTool(
+            name=self.entity.identity.name,
+            description=self.entity.description.llm if self.entity.description else "",
+            parameters={
+                "type": "object",
+                "properties": {},
+                "required": [],
+            },
+        )
+
+        parameters = self.get_merged_runtime_parameters()
+        for parameter in parameters:
+            if parameter.form != ToolParameter.ToolParameterForm.LLM:
+                continue
+
+            parameter_type = parameter.type.as_normal_type()
+            if parameter.type in {
+                ToolParameter.ToolParameterType.SYSTEM_FILES,
+                ToolParameter.ToolParameterType.FILE,
+                ToolParameter.ToolParameterType.FILES,
+            }:
+                # Determine the description based on parameter type
+                if parameter.type == ToolParameter.ToolParameterType.FILE:
+                    file_format_desc = " Input the file id with format: [File: file_id]."
+                else:
+                    file_format_desc = " Input the file id with format: [Files: file_id1, file_id2, ...]."
+
+                message_tool.parameters["properties"][parameter.name] = {
+                    "type": "string",
+                    "description": (parameter.llm_description or "") + file_format_desc,
+                }
+                continue
+
+            enum = []
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options] if parameter.options else []
+
+            message_tool.parameters["properties"][parameter.name] = (
+                {
+                    "type": parameter_type,
+                    "description": parameter.llm_description or "",
+                }
+                if parameter.input_schema is None
+                else parameter.input_schema
+            )
+
+            if len(enum) > 0:
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum
+
+            if parameter.required:
+                message_tool.parameters["required"].append(parameter.name)
+
+        return message_tool
+
     def create_image_message(
         self,
         image: str,
@@ -166,7 +223,7 @@ class Tool(ABC):
             type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
         )
 
-    def create_file_message(self, file: "File") -> ToolInvokeMessage:
+    def create_file_message(self, file: File) -> ToolInvokeMessage:
         return ToolInvokeMessage(
             type=ToolInvokeMessage.MessageType.FILE,
             message=ToolInvokeMessage.FileMessage(),
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index 84efefba07..51b0407886 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 from core.tools.__base.tool import Tool
@@ -24,7 +26,7 @@ class BuiltinTool(Tool):
         super().__init__(**kwargs)
         self.provider = provider
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
         """
         fork a new tool with metadata
         :return: the new tool
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index 0cc992155a..e2f6c00555 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pydantic import Field
 from sqlalchemy import select
 
@@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = []
 
     @classmethod
-    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
+    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
         credentials_schema = [
             ProviderConfig(
                 name="auth_type",
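`to_prompt_message_tool` flattens a tool's LLM-form parameters into a JSON-schema-style structure. For a hypothetical `web_search` tool with one required `query` string and a `lang` select, the result would look roughly like this (field values illustrative, not taken from the diff):

```python
# Approximate shape of the PromptMessageTool produced for a hypothetical tool.
message_tool = {
    "name": "web_search",
    "description": "Search the web.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search terms"},
            "lang": {"type": "string", "description": "UI language", "enum": ["en", "zh"]},
        },
        "required": ["query"],
    },
}
```

File-type parameters are flattened to strings with an appended `[File: file_id]` convention so the LLM can reference uploads by id.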
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 583a3584f7..b5c7a6310c 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import contextlib
 from collections.abc import Mapping
@@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
     MCP = auto()
 
     @classmethod
-    def value_of(cls, value: str) -> "ToolProviderType":
+    def value_of(cls, value: str) -> ToolProviderType:
         """
         Get value of given mode.
 
@@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
     OPENAI_ACTIONS = auto()
 
     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderSchemaType":
+    def value_of(cls, value: str) -> ApiProviderSchemaType:
         """
         Get value of given mode.
 
@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
     API_KEY_QUERY = auto()
 
     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderAuthType":
+    def value_of(cls, value: str) -> ApiProviderAuthType:
         """
         Get value of given mode.
 
@@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
         typ: ToolParameterType,
         required: bool,
         options: list[str] | None = None,
-    ) -> "ToolParameter":
+    ) -> ToolParameter:
         """
         get a simple tool parameter
 
@@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
     tool_config: dict | None = None
 
     @classmethod
-    def empty(cls) -> "ToolInvokeMeta":
+    def empty(cls) -> ToolInvokeMeta:
         """
         Get an empty instance of ToolInvokeMeta
         """
         return cls(time_cost=0.0, error=None, tool_config={})
 
     @classmethod
-    def error_instance(cls, error: str) -> "ToolInvokeMeta":
+    def error_instance(cls, error: str) -> ToolInvokeMeta:
         """
         Get an instance of ToolInvokeMeta with error
         """
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
index 96917045e3..ef9e9c103a 100644
--- a/api/core/tools/mcp_tool/tool.py
+++ b/api/core/tools/mcp_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import json
 import logging
@@ -118,7 +120,7 @@ class MCPTool(Tool):
             for item in json_list:
                 yield self.create_json_message(item)
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
         return MCPTool(
             entity=self.entity,
             runtime=runtime,
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index 828dc3b810..d3a2ad488c 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Generator
 from typing import Any
 
@@ -46,7 +48,7 @@ class PluginTool(Tool):
             message_id=message_id,
         )
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
         return PluginTool(
             entity=self.entity,
             runtime=runtime,
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
index fef3157f27..22e099deba 100644
--- a/api/core/tools/signature.py
+++ b/api/core/tools/signature.py
@@ -7,12 +7,12 @@ import time
 from configs import dify_config
 
 
-def sign_tool_file(tool_file_id: str, extension: str) -> str:
+def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
     """
     sign file to get a temporary url for plugin access
     """
-    # Use internal URL for plugin/tool file access in Docker environments
-    base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+    # Use the internal URL for plugin/tool file access in Docker environments, unless the URL is for external use
+    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
     file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
     timestamp = str(int(time.time()))
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index 5422f5250b..a706f101ca 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Mapping
 
 from pydantic import Field
 
@@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
         self.provider_id = provider_id
 
     @classmethod
-    def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
+    def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
         with session_factory.create_session() as session, session.begin():
             app = session.get(App, db_provider.app_id)
             if not app:
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 30334f5da8..81a1d54199 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
@@ -181,7 +183,7 @@ class WorkflowTool(Tool):
                 return found
         return None
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
         """
         fork a new tool with metadata
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index ce71711344..13b926c978 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Mapping
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from core.file.models import File
 
@@ -52,7 +54,7 @@ class SegmentType(StrEnum):
         return self in _ARRAY_TYPES
 
     @classmethod
-    def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+    def infer_segment_type(cls, value: Any) -> SegmentType | None:
         """
         Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
 
@@ -173,7 +175,7 @@ class SegmentType(StrEnum):
         raise AssertionError("this statement should be unreachable.")
 
     @staticmethod
-    def cast_value(value: Any, type_: "SegmentType"):
+    def cast_value(value: Any, type_: SegmentType):
         # Cast Python's `bool` type to `int` when the runtime type requires
         # an integer or number.
         #
@@ -193,7 +195,7 @@ class SegmentType(StrEnum):
             return [int(i) for i in value]
         return value
 
-    def exposed_type(self) -> "SegmentType":
+    def exposed_type(self) -> SegmentType:
         """Returns the type exposed to the frontend.
 
         The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are
         returned as `NUMBER` here.
         """
@@ -202,7 +204,7 @@ class SegmentType(StrEnum):
             return SegmentType.NUMBER
         return self
 
-    def element_type(self) -> "SegmentType | None":
+    def element_type(self) -> SegmentType | None:
         """Return the element type of the current segment type, or `None` if
         the element type is undefined.
 
         Raises:
@@ -217,7 +219,7 @@ class SegmentType(StrEnum):
         return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
 
     @staticmethod
-    def get_zero_value(t: "SegmentType"):
+    def get_zero_value(t: SegmentType):
         # Lazy import to avoid circular dependency
         from factories import variable_factory
 
diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py
index be70e467a0..0f3b9a5239 100644
--- a/api/core/workflow/entities/__init__.py
+++ b/api/core/workflow/entities/__init__.py
@@ -1,11 +1,16 @@
 from .agent import AgentNodeStrategyInit
 from .graph_init_params import GraphInitParams
+from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
 from .workflow_execution import WorkflowExecution
 from .workflow_node_execution import WorkflowNodeExecution
 
 __all__ = [
     "AgentNodeStrategyInit",
     "GraphInitParams",
+    "ToolCall",
+    "ToolCallResult",
+    "ToolResult",
+    "ToolResultStatus",
     "WorkflowExecution",
     "WorkflowNodeExecution",
 ]
diff --git a/api/core/workflow/entities/tool_entities.py b/api/core/workflow/entities/tool_entities.py
new file mode 100644
index 0000000000..7e71a86849
--- /dev/null
+++ b/api/core/workflow/entities/tool_entities.py
@@ -0,0 +1,39 @@
+from enum import StrEnum
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+
+
+class ToolResultStatus(StrEnum):
+    SUCCESS = "success"
+    ERROR = "error"
+
+
+class ToolCall(BaseModel):
+    id: str | None = Field(default=None, description="Unique identifier for this tool call")
+    name: str | None = Field(default=None, description="Name of the tool being called")
+    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
+
+
+class ToolResult(BaseModel):
+    id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
+    name: str | None = Field(default=None, description="Name of the tool")
+    output: str | None = Field(default=None, description="Tool output text, error or success message")
+    files: list[str] = Field(default_factory=list, description="File IDs produced by the tool")
+    status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
+    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
+    icon: str | dict | None = Field(default=None, description="Icon of the tool")
+    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
+
+
+class ToolCallResult(BaseModel):
+    id: str | None = Field(default=None, description="Identifier for the tool call")
+    name: str | None = Field(default=None, description="Name of the tool")
+    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
+    output: str | None = Field(default=None, description="Tool output text, error or success message")
+    files: list[File] = Field(default_factory=list, description="Files produced by the tool")
+    status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
+    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
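The new entities are plain pydantic models, so constructing one is straightforward. A hedged usage example (field values illustrative):

```python
from core.workflow.entities import ToolCallResult, ToolResultStatus

# Illustrative values only; the import path is confirmed by the __init__ exports above.
result = ToolCallResult(
    id="call_1",
    name="web_search",
    arguments='{"query": "dify"}',
    output="3 results",
    status=ToolResultStatus.SUCCESS,
    elapsed_time=0.42,
)
print(result.files)  # [] -- default_factory=list keeps instances independent
```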
diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py
index a8a86d3db2..1b3fb36f1f 100644
--- a/api/core/workflow/entities/workflow_execution.py
+++ b/api/core/workflow/entities/workflow_execution.py
@@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
 implementation details like tenant_id, app_id, etc.
 """
 
+from __future__ import annotations
+
 from collections.abc import Mapping
 from datetime import datetime
 from typing import Any
@@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
         graph: Mapping[str, Any],
         inputs: Mapping[str, Any],
         started_at: datetime,
-    ) -> "WorkflowExecution":
+    ) -> WorkflowExecution:
         return WorkflowExecution(
             id_=id_,
             workflow_id=workflow_id,
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index 94ec89ad39..2e08d0359d 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -248,6 +248,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
     LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
     DATASOURCE_INFO = "datasource_info"
+    LLM_CONTENT_SEQUENCE = "llm_content_sequence"
+    LLM_TRACE = "llm_trace"
     COMPLETED_REASON = "completed_reason"  # completed reason for loop node
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index ba5a01fc94..7be94c2426 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
@@ -175,7 +177,7 @@ class Graph:
     def _create_node_instances(
         cls,
         node_configs_map: dict[str, dict[str, object]],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
     ) -> dict[str, Node]:
         """
         Create node instances from configurations using the node factory.
@@ -197,7 +199,7 @@ class Graph:
         return nodes
 
     @classmethod
-    def new(cls) -> "GraphBuilder":
+    def new(cls) -> GraphBuilder:
         """Create a fluent builder for assembling a graph programmatically."""
         return GraphBuilder(graph_cls=cls)
 
@@ -284,9 +286,9 @@ class Graph:
         cls,
         *,
         graph_config: Mapping[str, object],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
         root_node_id: str | None = None,
-    ) -> "Graph":
+    ) -> Graph:
         """
         Initialize graph
 
@@ -383,7 +385,7 @@ class GraphBuilder:
         self._edges: list[Edge] = []
         self._edge_counter = 0
 
-    def add_root(self, node: Node) -> "GraphBuilder":
+    def add_root(self, node: Node) -> GraphBuilder:
         """Register the root node. Must be called exactly once."""
 
         if self._nodes:
@@ -398,7 +400,7 @@ class GraphBuilder:
         *,
         from_node_id: str | None = None,
         source_handle: str = "source",
-    ) -> "GraphBuilder":
+    ) -> GraphBuilder:
         """Append a node and connect it from the specified predecessor."""
 
         if not self._nodes:
@@ -419,7 +421,7 @@ class GraphBuilder:
 
         return self
 
-    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
+    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
         """Connect two existing nodes without adding a new node."""
 
         if tail not in self._nodes_by_id:
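`Graph.new()` returns the fluent `GraphBuilder`. A hedged usage sketch using only the methods visible in the diff; `start_node` and `llm_node` are assumed, pre-built `Node` instances:

```python
# Illustrative only: assumes start_node and llm_node already exist.
builder = Graph.new()                                             # returns a GraphBuilder
builder = builder.add_root(start_node)                            # register the root exactly once
builder = builder.add_node(llm_node, from_node_id=start_node.id)  # append a node + wire an edge
builder = builder.connect(tail=start_node.id, head=llm_node.id)   # extra edge between existing nodes
```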
""" +from __future__ import annotations + import contextvars import logging import queue +import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final @@ -75,10 +78,13 @@ class GraphEngine: scale_down_idle_time: float | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" + # stop event + self._stop_event = threading.Event() # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel @@ -177,6 +183,7 @@ class GraphEngine: max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, scale_down_idle_time=self._scale_down_idle_time, + stop_event=self._stop_event, ) # === Orchestration === @@ -207,6 +214,7 @@ class GraphEngine: event_handler=self._event_handler_registry, execution_coordinator=self._execution_coordinator, event_emitter=self._event_manager, + stop_event=self._stop_event, ) # === Validation === @@ -226,7 +234,7 @@ class GraphEngine: ) -> None: layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - def layer(self, layer: GraphEngineLayer) -> "GraphEngine": + def layer(self, layer: GraphEngineLayer) -> GraphEngine: """Add a layer for extending functionality.""" self._layers.append(layer) self._bind_layer_context(layer) @@ -324,6 +332,7 @@ class GraphEngine: def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + self._stop_event.clear() paused_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() @@ -351,13 +360,12 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" + self._stop_event.set() self._dispatcher.stop() self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers - logger = logging.getLogger(__name__) - for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 334a3f77bf..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -44,6 +44,7 @@ class Dispatcher: event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandler", execution_coordinator: ExecutionCoordinator, + stop_event: threading.Event, event_emitter: EventManager | None = None, ) -> None: """ @@ -61,7 +62,7 @@ class Dispatcher: self._event_emitter = event_emitter self._thread: threading.Thread | None = None - self._stop_event = threading.Event() + self._stop_event = stop_event self._start_time: float | None = None def start(self) -> None: @@ -69,16 +70,14 @@ class Dispatcher: if self._thread and self._thread.is_alive(): return - self._stop_event.clear() self._start_time = time.time() self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) self._thread.start() def stop(self) -> None: """Stop the dispatcher thread.""" - self._stop_event.set() if self._thread and self._thread.is_alive(): - self._thread.join(timeout=10.0) + self._thread.join(timeout=2.0) def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py 
diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py
index 1144e1de69..a9d4f470e5 100644
--- a/api/core/workflow/graph_engine/ready_queue/factory.py
+++ b/api/core/workflow/graph_engine/ready_queue/factory.py
@@ -2,6 +2,8 @@
 Factory for creating ReadyQueue instances from serialized state.
 """
 
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 from .in_memory import InMemoryReadyQueue
@@ -11,7 +13,7 @@ if TYPE_CHECKING:
     from .protocol import ReadyQueue
 
 
-def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
+def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
     """
     Create a ReadyQueue instance from a serialized state.
diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
index 98e0ea91ef..c5ea94ba80 100644
--- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py
+++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
 
 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
-from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
+from core.workflow.graph_events import (
+    ChunkType,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ToolCall,
+    ToolResult,
+)
 from core.workflow.nodes.base.template import TextSegment, VariableSegment
 from core.workflow.runtime import VariablePool
 
@@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
         selector: Sequence[str],
         chunk: str,
         is_final: bool = False,
+        chunk_type: ChunkType = ChunkType.TEXT,
+        tool_call: ToolCall | None = None,
+        tool_result: ToolResult | None = None,
     ) -> NodeRunStreamChunkEvent:
         """Create a stream chunk event with consistent structure.
 
         For selectors with special prefixes (sys, env, conversation), we use
         the active response node's information since these are not actual
         node IDs.
+
+        Args:
+            node_id: The node ID to attribute the event to
+            execution_id: The execution ID for this node
+            selector: The variable selector
+            chunk: The chunk content
+            is_final: Whether this is the final chunk
+            chunk_type: The semantic type of the chunk being streamed
+            tool_call: Structured data for tool_call chunks
+            tool_result: Structured data for tool_result chunks
         """
         # Check if this is a special selector that doesn't correspond to a node
         if selector and selector[0] not in self._graph.nodes and self._active_session:
@@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
                 selector=selector,
                 chunk=chunk,
                 is_final=is_final,
+                chunk_type=chunk_type,
+                tool_call=tool_call,
+                tool_result=tool_result,
             )
 
         # Standard case: selector refers to an actual node
@@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
             selector=selector,
             chunk=chunk,
             is_final=is_final,
+            chunk_type=chunk_type,
+            tool_call=tool_call,
+            tool_result=tool_result,
         )
 
     def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@@ -356,6 +381,8 @@ class ResponseStreamCoordinator:
         Handles both regular node selectors and special system selectors
         (sys, env, conversation). For special selectors, we attribute the
         output to the active response node.
+
+        For object-type variables, automatically streams all child fields that have stream events.
         """
         events: list[NodeRunStreamChunkEvent] = []
         source_selector_prefix = segment.selector[0] if segment.selector else ""
@@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
         # Determine which node to attribute the output to
         # For special selectors (sys, env, conversation), use the active response node
         # For regular selectors, use the source node
-        if self._active_session and source_selector_prefix not in self._graph.nodes:
-            # Special selector - use active response node
-            output_node_id = self._active_session.node_id
-        else:
-            # Regular node selector
-            output_node_id = source_selector_prefix
+        active_session = self._active_session
+        special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
+        output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
 
         execution_id = self._get_or_create_execution_id(output_node_id)
 
-        # Stream all available chunks
-        while self._has_unread_stream(segment.selector):
-            if event := self._pop_stream_chunk(segment.selector):
-                # For special selectors, we need to update the event to use
-                # the active response node's information
-                if self._active_session and source_selector_prefix not in self._graph.nodes:
-                    response_node = self._graph.nodes[self._active_session.node_id]
-                    # Create a new event with the response node's information
-                    # but keep the original selector
-                    updated_event = NodeRunStreamChunkEvent(
-                        id=execution_id,
-                        node_id=response_node.id,
-                        node_type=response_node.node_type,
-                        selector=event.selector,  # Keep original selector
-                        chunk=event.chunk,
-                        is_final=event.is_final,
-                    )
-                    events.append(updated_event)
-                else:
-                    # Regular node selector - use event as is
-                    events.append(event)
+        # Check if there's a direct stream for this selector
+        has_direct_stream = (
+            tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
+        )
 
-        # Check if this is the last chunk by looking ahead
-        stream_closed = self._is_stream_closed(segment.selector)
-        # Check if stream is closed to determine if segment is complete
-        if stream_closed:
-            is_complete = True
+        stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
 
-        elif value := self._variable_pool.get(segment.selector):
-            # Process scalar value
-            is_last_segment = bool(
-                self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
-            )
-            events.append(
-                self._create_stream_chunk_event(
-                    node_id=output_node_id,
-                    execution_id=execution_id,
-                    selector=segment.selector,
-                    chunk=value.markdown,
-                    is_final=is_last_segment,
+        if stream_targets:
+            all_complete = True
+
+            for target_selector in stream_targets:
+                while self._has_unread_stream(target_selector):
+                    if event := self._pop_stream_chunk(target_selector):
+                        events.append(
+                            self._rewrite_stream_event(
+                                event=event,
+                                output_node_id=output_node_id,
+                                execution_id=execution_id,
+                                special_selector=bool(special_selector),
+                            )
+                        )
+
+                if not self._is_stream_closed(target_selector):
+                    all_complete = False
+
+            is_complete = all_complete
+
+        # Fallback: check if scalar value exists in variable pool
+        if not is_complete and not has_direct_stream:
+            if value := self._variable_pool.get(segment.selector):
+                # Process scalar value
+                is_last_segment = bool(
+                    self._active_session
+                    and self._active_session.index == len(self._active_session.template.segments) - 1
+                )
+                events.append(
+                    self._create_stream_chunk_event(
+                        node_id=output_node_id,
+ execution_id=execution_id, + selector=segment.selector, + chunk=value.markdown, + is_final=is_last_segment, + ) + ) + is_complete = True return events, is_complete + def _rewrite_stream_event( + self, + event: NodeRunStreamChunkEvent, + output_node_id: str, + execution_id: str, + special_selector: bool, + ) -> NodeRunStreamChunkEvent: + """Rewrite event to attribute to active response node when selector is special.""" + if not special_selector: + return event + + return self._create_stream_chunk_event( + node_id=output_node_id, + execution_id=execution_id, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=event.chunk_type, + tool_call=event.tool_call, + tool_result=event.tool_result, + ) + def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: """Process a text segment. Returns (events, is_complete).""" assert self._active_session is not None @@ -513,6 +561,36 @@ class ResponseStreamCoordinator: # ============= Internal Stream Management Methods ============= + def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]: + """Find all child stream selectors that are descendants of the parent selector. + + For example, if parent_selector is ['llm', 'generation'], this will find: + - ['llm', 'generation', 'content'] + - ['llm', 'generation', 'tool_calls'] + - ['llm', 'generation', 'tool_results'] + - ['llm', 'generation', 'thought'] + + Args: + parent_selector: The parent selector to search for children + + Returns: + List of child selector tuples found in stream buffers or closed streams + """ + parent_key = tuple(parent_selector) + parent_len = len(parent_key) + child_streams: set[tuple[str, ...]] = set() + + # Search in both active buffers and closed streams + all_selectors = set(self._stream_buffers.keys()) | self._closed_streams + + for selector_key in all_selectors: + # Check if this selector is a direct child of the parent + # Direct child means: len(child) == len(parent) + 1 and child starts with parent + if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key: + child_streams.add(selector_key) + + return sorted(child_streams) + def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: """ Append a stream chunk to the internal buffer. diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8b7c2e441e..8ceaa428c3 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally by ResponseStreamCoordinator to manage streaming sessions. """ +from __future__ import annotations + from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode @@ -27,7 +29,7 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> "ResponseSession": + def from_node(cls, node: Node) -> ResponseSession: """ Create a ResponseSession from an AnswerNode or EndNode. 
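+
+        Illustrative usage (editor's sketch, not part of the original patch):
+            session = ResponseSession.from_node(answer_node)
+            assert session.index == 0  # starts at the first template segment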
diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e37a08ae47..83419830b6 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -42,6 +42,7 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: Sequence[GraphEngineLayer], + stop_event: threading.Event, worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -65,13 +66,16 @@ class Worker(threading.Thread): self._worker_id = worker_id self._flask_app = flask_app self._context_vars = context_vars - self._stop_event = threading.Event() self._last_task_time = time.time() + self._stop_event = stop_event self._layers = layers if layers is not None else [] def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() + """Worker is controlled via shared stop_event from GraphEngine. + + This method is a no-op retained for backward compatibility. + """ + pass @property def is_idle(self) -> bool: diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 5b9234586b..df76ebe882 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -41,6 +41,7 @@ class WorkerPool: event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: list[GraphEngineLayer], + stop_event: threading.Event, flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -81,6 +82,7 @@ class WorkerPool: self._worker_counter = 0 self._lock = threading.RLock() self._running = False + self._stop_event = stop_event # No longer tracking worker states with callbacks to avoid lock contention @@ -135,7 +137,7 @@ class WorkerPool: # Wait for workers to finish for worker in self._workers: if worker.is_alive(): - worker.join(timeout=10.0) + worker.join(timeout=2.0) self._workers.clear() @@ -152,6 +154,7 @@ class WorkerPool: worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, + stop_event=self._stop_event, ) worker.start() diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 7a5edbb331..4ee0ec94d2 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -36,6 +36,7 @@ from .loop import ( # Node events from .node import ( + ChunkType, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -44,10 +45,13 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + ToolCall, + ToolResult, ) __all__ = [ "BaseGraphEvent", + "ChunkType", "GraphEngineEvent", "GraphNodeEventBase", "GraphRunAbortedEvent", @@ -73,4 +77,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "ToolCall", + "ToolResult", ] diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index f225798d41..2ae4fcd919 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities import 
AgentNodeStrategyInit, ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -21,13 +22,39 @@ class NodeRunStartedEvent(GraphNodeEventBase): provider_id: str = "" +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + THOUGHT_START = "thought_start" # Agent thought start + THOUGHT_END = "thought_end" # Agent thought end + + class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields + """Stream chunk event for workflow node execution.""" + + # Base fields selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call: ToolCall | None = Field( + default=None, + description="structured payload for tool_call chunks", + ) + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_result: ToolResult | None = Field( + default=None, + description="structured payload for tool_result chunks", + ) class NodeRunRetrieverResourceEvent(GraphNodeEventBase): diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index f14a594c85..67263311b9 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,16 +13,21 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( + ChunkType, ModelInvokeCompletedEvent, PauseRequestedEvent, RunRetrieverResourceEvent, RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) __all__ = [ "AgentLogEvent", + "ChunkType", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", @@ -39,4 +44,7 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "ThoughtChunkEvent", + "ToolCallChunkEvent", + "ToolResultChunkEvent", ] diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index e4fa52f444..f0bbc4d96f 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -1,11 +1,13 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.entities import ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from core.workflow.node_events import NodeRunResult @@ -32,13 +34,60 @@ class RunRetryEvent(NodeEventBase): start_at: datetime = Field(..., description="Retry start time") +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + THOUGHT_START = "thought_start" # Agent thought start + THOUGHT_END = "thought_end" # Agent thought end + + 
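+# Editor's note (illustrative): since ChunkType is a StrEnum, members compare
+# equal to their string values -- e.g. ChunkType.TOOL_CALL == "tool_call" is
+# True -- so serialized events and raw-string checks stay interchangeable.
+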
class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields + """Base stream chunk event - normal text streaming output.""" + selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks") + tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks") + + +class ToolCallChunkEvent(StreamChunkEvent): + """Tool call streaming event - tool call arguments streaming output.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True) + tool_call: ToolCall | None = Field(default=None, description="structured tool call payload") + + +class ToolResultChunkEvent(StreamChunkEvent): + """Tool result event - tool execution result.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True) + tool_result: ToolResult | None = Field(default=None, description="structured tool result payload") + + +class ThoughtStartChunkEvent(StreamChunkEvent): + """Agent thought start streaming event - Agent thinking process (ReAct).""" + + chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True) + + +class ThoughtEndChunkEvent(StreamChunkEvent): + """Agent thought end streaming event - Agent thinking process (ReAct).""" + + chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True) + + +class ThoughtChunkEvent(StreamChunkEvent): + """Agent thought streaming event - Agent thinking process (ReAct).""" + + chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True) class StreamCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4be006de11..234651ce96 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast @@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]): variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, - strategy: "PluginAgentStrategy", + strategy: PluginAgentStrategy, ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. @@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]): def _generate_credentials( self, parameters: dict[str, Any], - ) -> "InvokeCredentials": + ) -> InvokeCredentials: """ Generate credentials based on the given agent parameters. 
""" @@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]): model_schema.features.remove(feature) return model_schema - def _filter_mcp_type_tool( - self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Filter MCP type tool :param strategy: plugin agent strategy diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aab6bbde4..e5a20c8e91 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from abc import ABC from builtins import type as type_ @@ -111,7 +113,7 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Cannot convert to number: {value}") @model_validator(mode="after") - def validate_value_type(self) -> "DefaultValue": + def validate_value_type(self) -> DefaultValue: # Type validation configuration type_validators = { DefaultValueType.STRING: { diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8ebba3659c..06e4c0440d 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import logging import operator @@ -46,6 +48,9 @@ from core.workflow.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) from core.workflow.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now @@ -59,7 +64,7 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): - node_type: ClassVar["NodeType"] + node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -198,14 +203,14 @@ class Node(Generic[NodeDataT]): return None # Global registry populated via __init_subclass__ - _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} def __init__( self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params self.id = id @@ -241,7 +246,7 @@ class Node(Generic[NodeDataT]): return @property - def graph_init_params(self) -> "GraphInitParams": + def graph_init_params(self) -> GraphInitParams: return self._graph_init_params @property @@ -264,6 +269,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def _should_stop(self) -> bool: + """Check if execution should be stopped.""" + return self.graph_runtime_state.stop_event.is_set() + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -332,6 +341,21 @@ class Node(Generic[NodeDataT]): yield event else: yield event + + if self._should_stop(): + error_message = "Execution cancelled" + yield NodeRunFailedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + start_at=self._start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + ), + error=error_message, + ) + return except Exception as e: 
logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( @@ -438,7 +462,7 @@ class Node(Generic[NodeDataT]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod - def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under core.workflow.nodes so subclasses register themselves on import. @@ -543,6 +567,8 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + return NodeRunStreamChunkEvent( id=self.execution_id, node_id=self._node_id, @@ -550,6 +576,60 @@ class Node(Generic[NodeDataT]): selector=event.selector, chunk=event.chunk, is_final=event.is_final, + chunk_type=ChunkType(event.chunk_type.value), + tool_call=event.tool_call, + tool_result=event.tool_result, + ) + + @_dispatch.register + def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_CALL, + tool_call=event.tool_call, + ) + + @_dispatch.register + def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.entities import ToolResult, ToolResultStatus + from core.workflow.graph_events import ChunkType + + tool_result = event.tool_result or ToolResult() + status: ToolResultStatus = tool_result.status or ToolResultStatus.SUCCESS + tool_result = tool_result.model_copy( + update={"status": status, "files": tool_result.files or []}, + ) + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_RESULT, + tool_result=tool_result, + ) + + @_dispatch.register + def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.THOUGHT, ) @_dispatch.register diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py index ba3e2058cf..81f4b9f6fb 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/core/workflow/nodes/base/template.py @@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes, similar to SegmentGroup but focused on template representation without values. """ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -58,7 +60,7 @@ class Template: segments: list[TemplateSegmentUnion] @classmethod - def from_answer_template(cls, template_str: str) -> "Template": + def from_answer_template(cls, template_str: str) -> Template: """Create a Template from an Answer node template string. 
Example: @@ -107,7 +109,7 @@ class Template: return cls(segments=segments) @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template": + def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: """Create a Template from an End node outputs configuration. End nodes are treated as templates of concatenated variables with newlines. diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index f7bc713f63..edd0d3d581 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -3,6 +3,7 @@ from .entities import ( LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + ToolMetadata, VisionConfig, ) from .node import LLMNode @@ -13,5 +14,6 @@ __all__ = [ "LLMNodeCompletionModelPromptTemplate", "LLMNodeData", "ModelConfig", + "ToolMetadata", "VisionConfig", ] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index fe6f2290aa..0d61e45003 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,10 +1,17 @@ +import re from collections.abc import Mapping, Sequence from typing import Any, Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator +from core.agent.entities import AgentLog, AgentResult +from core.file import File from core.model_runtime.entities import ImagePromptMessageContent, LLMMode +from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.entities import ToolCall, ToolCallResult +from core.workflow.node_events import AgentLogEvent from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector @@ -58,6 +65,268 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): jinja2_text: str | None = None +class ToolMetadata(BaseModel): + """ + Tool metadata for LLM node with tool support. + + Defines the essential fields needed for tool configuration, + particularly the 'type' field to identify tool provider type. 
+ """ + + # Core fields + enabled: bool = True + type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow") + provider_name: str = Field(..., description="Tool provider name/identifier") + tool_name: str = Field(..., description="Tool name") + + # Optional fields + plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools") + credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication") + + # Configuration fields + parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters") + settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration") + extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description") + + +class ModelTraceSegment(BaseModel): + """Model invocation trace segment with token usage and output.""" + + text: str | None = Field(None, description="Model output text content") + reasoning: str | None = Field(None, description="Reasoning/thought content from model") + tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model") + + +class ToolTraceSegment(BaseModel): + """Tool invocation trace segment with call details and result.""" + + id: str | None = Field(default=None, description="Unique identifier for this tool call") + name: str | None = Field(default=None, description="Name of the tool being called") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + output: str | None = Field(default=None, description="Tool call result") + + +class LLMTraceSegment(BaseModel): + """ + Streaming trace segment for LLM tool-enabled runs. + + Represents alternating model and tool invocations in sequence: + model -> tool -> model -> tool -> ... + + Each segment records its execution duration. + """ + + type: Literal["model", "tool"] + duration: float = Field(..., description="Execution duration in seconds") + usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call") + output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment") + + # Common metadata for both model and tool segments + provider: str | None = Field(default=None, description="Model or tool provider identifier") + name: str | None = Field(default=None, description="Name of the model or tool") + icon: str | None = Field(default=None, description="Icon for the provider") + icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider") + error: str | None = Field(default=None, description="Error message if segment failed") + status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status") + + +class LLMGenerationData(BaseModel): + """Generation data from LLM invocation with tools. + + For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2 + - reasoning_contents: [thought1, thought2, ...] - one element per turn + - tool_calls: [{id, name, arguments, result}, ...] 
- all tool calls with results
+    """
+
+    text: str = Field(..., description="Accumulated text content from all turns")
+    reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
+    tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
+    sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
+    usage: LLMUsage = Field(..., description="LLM usage statistics")
+    finish_reason: str | None = Field(None, description="Finish reason from LLM")
+    files: list[File] = Field(default_factory=list, description="Generated files")
+    trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
+
+
+class ThinkTagStreamParser:
+    """Lightweight state machine to split streaming chunks by <think> tags."""
+
+    _START_PATTERN = re.compile(r"<think(\s[^>]*)?>", re.IGNORECASE)
+    _END_PATTERN = re.compile(r"</think\s*>", re.IGNORECASE)
+    _START_PREFIX = "<think"
+    _END_PREFIX = "</think"
+
+    def __init__(self):
+        self._buffer = ""
+        self._in_think = False
+
+    def _suffix_prefix_len(self, text: str, prefix: str) -> int:
+        """Return length of the longest suffix of `text` that is a prefix of `prefix`."""
+        max_len = min(len(text), len(prefix) - 1)
+        for i in range(max_len, 0, -1):
+            if text[-i:].lower() == prefix[:i].lower():
+                return i
+        return 0
+
+    def process(self, chunk: str) -> list[tuple[str, str]]:
+        """
+        Split incoming chunk into ('thought' | 'text', content) tuples.
+        Content excludes the tags themselves and handles split tags across chunks.
+        """
+        parts: list[tuple[str, str]] = []
+        self._buffer += chunk
+
+        while self._buffer:
+            if self._in_think:
+                end_match = self._END_PATTERN.search(self._buffer)
+                if end_match:
+                    thought_text = self._buffer[: end_match.start()]
+                    if thought_text:
+                        parts.append(("thought", thought_text))
+                    parts.append(("thought_end", ""))
+                    self._buffer = self._buffer[end_match.end() :]
+                    self._in_think = False
+                    continue
+
+                hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
+                emit = self._buffer[: len(self._buffer) - hold_len]
+                if emit:
+                    parts.append(("thought", emit))
+                self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
+                break
+
+            start_match = self._START_PATTERN.search(self._buffer)
+            if start_match:
+                prefix = self._buffer[: start_match.start()]
+                if prefix:
+                    parts.append(("text", prefix))
+                self._buffer = self._buffer[start_match.end() :]
+                parts.append(("thought_start", ""))
+                self._in_think = True
+                continue
+
+            hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
+            emit = self._buffer[: len(self._buffer) - hold_len]
+            if emit:
+                parts.append(("text", emit))
+            self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
+            break
+
+        cleaned_parts: list[tuple[str, str]] = []
+        for kind, content in parts:
+            # Extra safeguard: strip any stray tags that slipped through.
+            content = self._START_PATTERN.sub("", content)
+            content = self._END_PATTERN.sub("", content)
+            if content or kind in {"thought_start", "thought_end"}:
+                cleaned_parts.append((kind, content))
+
+        return cleaned_parts
+
+    def flush(self) -> list[tuple[str, str]]:
+        """Flush remaining buffer when the stream ends."""
+        if not self._buffer:
+            return []
+        kind = "thought" if self._in_think else "text"
+        content = self._buffer
+        # Drop dangling partial tags instead of emitting them
+        if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
+            content = ""
+        self._buffer = ""
+        if not content and not self._in_think:
+            return []
+        # Strip any complete tags that might still be present.
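+        # (Editor's illustration of the splitting behavior above: a tag broken
+        # across chunks is held back and reassembled, so process("<thi") emits
+        # nothing, while a following process("nk>x</think>") yields
+        # [("thought_start", ""), ("thought", "x"), ("thought_end", "")].)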
+ content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + + result: list[tuple[str, str]] = [] + if content: + result.append((kind, content)) + if self._in_think: + result.append(("thought_end", "")) + self._in_think = False + return result + + +class StreamBuffers(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser) + pending_thought: list[str] = Field(default_factory=list) + pending_content: list[str] = Field(default_factory=list) + pending_tool_calls: list[ToolCall] = Field(default_factory=list) + current_turn_reasoning: list[str] = Field(default_factory=list) + reasoning_per_turn: list[str] = Field(default_factory=list) + + +class TraceState(BaseModel): + trace_segments: list[LLMTraceSegment] = Field(default_factory=list) + tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict) + tool_call_index_map: dict[str, int] = Field(default_factory=dict) + model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment") + pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment") + + +class AggregatedResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + text: str = "" + files: list[File] = Field(default_factory=list) + usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) + finish_reason: str | None = None + + +class AgentContext(BaseModel): + agent_logs: list[AgentLogEvent] = Field(default_factory=list) + agent_result: AgentResult | None = None + + +class ToolOutputState(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + stream: StreamBuffers = Field(default_factory=StreamBuffers) + trace: TraceState = Field(default_factory=TraceState) + aggregate: AggregatedResult = Field(default_factory=AggregatedResult) + agent: AgentContext = Field(default_factory=AgentContext) + + +class ToolLogPayload(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + tool_name: str = "" + tool_call_id: str = "" + tool_args: dict[str, Any] = Field(default_factory=dict) + tool_output: Any = None + tool_error: Any = None + files: list[Any] = Field(default_factory=list) + meta: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_log(cls, log: AgentLog) -> "ToolLogPayload": + data = log.data or {} + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + @classmethod + def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload": + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate @@ -86,6 +355,10 @@ class LLMNodeData(BaseNodeData): ), ) + # Tool support + tools: Sequence[ToolMetadata] = Field(default_factory=list) + max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node") + @field_validator("prompt_config", mode="before") @classmethod def 
convert_none_prompt_config(cls, v: Any): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 04e2802191..20de710db3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import io import json @@ -9,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import select +from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext +from core.agent.patterns import StrategyFactory from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage @@ -46,7 +50,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file +from core.tools.tool_manager import ToolManager from core.variables import ( ArrayFileSegment, ArraySegment, @@ -56,7 +62,8 @@ from core.variables import ( StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams +from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus +from core.workflow.entities.tool_entities import ToolCallResult from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -64,13 +71,18 @@ from core.workflow.enums import ( WorkflowNodeExecutionStatus, ) from core.workflow.node_events import ( + AgentLogEvent, ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) +from core.workflow.node_events.node import ThoughtEndChunkEvent, ThoughtStartChunkEvent from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser @@ -81,10 +93,21 @@ from models.model import UploadFile from . import llm_utils from .entities import ( + AgentContext, + AggregatedResult, + LLMGenerationData, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, + LLMTraceSegment, ModelConfig, + ModelTraceSegment, + StreamBuffers, + ThinkTagStreamParser, + ToolLogPayload, + ToolOutputState, + ToolTraceSegment, + TraceState, ) from .exc import ( InvalidContextStructureError, @@ -113,7 +136,7 @@ class LLMNode(Node[LLMNodeData]): # Instance attributes specific to LLMNode. 
# Output variable for file
-    _file_outputs: list["File"]
+    _file_outputs: list[File]
 
     _llm_file_saver: LLMFileSaver
 
@@ -121,8 +144,8 @@
         self,
         id: str,
         config: Mapping[str, Any],
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
         *,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -149,11 +172,11 @@ class LLMNode(Node[LLMNodeData]):
     def _run(self) -> Generator:
         node_inputs: dict[str, Any] = {}
         process_data: dict[str, Any] = {}
-        result_text = ""
         clean_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
-        reasoning_content = None
+        reasoning_content = ""  # Initialize as empty string for consistency
+        clean_text = ""  # Initialize clean_text to avoid UnboundLocalError
         variable_pool = self.graph_runtime_state.variable_pool
 
         try:
@@ -234,55 +257,58 @@ class LLMNode(Node[LLMNodeData]):
                 context_files=context_files,
             )
 
-            # handle invoke result
-            generator = LLMNode.invoke_llm(
-                node_data_model=self.node_data.model,
-                model_instance=model_instance,
-                prompt_messages=prompt_messages,
-                stop=stop,
-                user_id=self.user_id,
-                structured_output_enabled=self.node_data.structured_output_enabled,
-                structured_output=self.node_data.structured_output,
-                file_saver=self._llm_file_saver,
-                file_outputs=self._file_outputs,
-                node_id=self._node_id,
-                node_type=self.node_type,
-                reasoning_format=self.node_data.reasoning_format,
-            )
-
+            # Variables for outputs
+            generation_data: LLMGenerationData | None = None
             structured_output: LLMStructuredOutput | None = None
 
-            for event in generator:
-                if isinstance(event, StreamChunkEvent):
-                    yield event
-                elif isinstance(event, ModelInvokeCompletedEvent):
-                    # Raw text
-                    result_text = event.text
-                    usage = event.usage
-                    finish_reason = event.finish_reason
-                    reasoning_content = event.reasoning_content or ""
+            # Check if tools are configured
+            if self.tool_call_enabled:
+                # Use tool-enabled invocation (Agent V2 style)
+                generator = self._invoke_llm_with_tools(
+                    model_instance=model_instance,
+                    prompt_messages=prompt_messages,
+                    stop=stop,
+                    files=files,
+                    variable_pool=variable_pool,
+                    node_inputs=node_inputs,
+                    process_data=process_data,
+                )
+            else:
+                # Use traditional LLM invocation
+                generator = LLMNode.invoke_llm(
+                    node_data_model=self._node_data.model,
+                    model_instance=model_instance,
+                    prompt_messages=prompt_messages,
+                    stop=stop,
+                    user_id=self.user_id,
+                    structured_output_enabled=self._node_data.structured_output_enabled,
+                    structured_output=self._node_data.structured_output,
+                    file_saver=self._llm_file_saver,
+                    file_outputs=self._file_outputs,
+                    node_id=self._node_id,
+                    node_type=self.node_type,
+                    reasoning_format=self._node_data.reasoning_format,
+                )
 
-                    # For downstream nodes, determine clean text based on reasoning_format
-                    if self.node_data.reasoning_format == "tagged":
-                        # Keep <think> tags for backward compatibility
-                        clean_text = result_text
-                    else:
-                        # Extract clean text from <think> tags
-                        clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
+            (
+                clean_text,
+                reasoning_content,
+                generation_reasoning_content,
+                generation_clean_content,
+                usage,
+                finish_reason,
+                structured_output,
+                generation_data,
+            ) = yield from self._stream_llm_events(generator, model_instance=model_instance)
 
-            # Process structured output if available from the event.
- structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - # deduct quota - llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event + # Extract variables from generation_data if available + if generation_data: + clean_text = generation_data.text + reasoning_content = "" + usage = generation_data.usage + finish_reason = generation_data.finish_reason + # Unified process_data building process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -293,24 +319,88 @@ class LLMNode(Node[LLMNodeData]): "model_provider": model_config.provider, "model_name": model_config.model, } + if self.tool_call_enabled and self._node_data.tools: + process_data["tools"] = [ + { + "type": tool.type.value if hasattr(tool.type, "value") else tool.type, + "provider_name": tool.provider_name, + "tool_name": tool.tool_name, + } + for tool in self._node_data.tools + if tool.enabled + ] + # Unified outputs building outputs = { "text": clean_text, "reasoning_content": reasoning_content, "usage": jsonable_encoder(usage), "finish_reason": finish_reason, } + + # Build generation field + if generation_data: + # Use generation_data from tool invocation (supports multi-turn) + generation = { + "content": generation_data.text, + "reasoning_content": generation_data.reasoning_contents, # [thought1, thought2, ...] + "tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls], + "sequence": generation_data.sequence, + } + files_to_output = generation_data.files + else: + # Traditional LLM invocation + generation_reasoning = generation_reasoning_content or reasoning_content + generation_content = generation_clean_content or clean_text + sequence: list[dict[str, Any]] = [] + if generation_reasoning: + sequence = [ + {"type": "reasoning", "index": 0}, + {"type": "content", "start": 0, "end": len(generation_content)}, + ] + generation = { + "content": generation_content, + "reasoning_content": [generation_reasoning] if generation_reasoning else [], + "tool_calls": [], + "sequence": sequence, + } + files_to_output = self._file_outputs + + outputs["generation"] = generation + if files_to_output: + outputs["files"] = ArrayFileSegment(value=files_to_output) if structured_output: outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) + if not self.tool_call_enabled: + # For tool calls, final events are already sent in _process_tool_outputs + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=True, + ) + + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + } + + if generation_data and generation_data.trace: + 
metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [ + segment.model_dump() for segment in generation_data.trace + ] yield StreamCompletedEvent( node_run_result=NodeRunResult( @@ -318,11 +408,7 @@ class LLMNode(Node[LLMNodeData]): inputs=node_inputs, process_data=process_data, outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, + metadata=metadata, llm_usage=usage, ) ) @@ -361,7 +447,7 @@ class LLMNode(Node[LLMNodeData]): structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -415,7 +501,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -444,6 +530,8 @@ class LLMNode(Node[LLMNodeData]): usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() + think_parser = ThinkTagStreamParser() + reasoning_chunks: list[str] = [] # Initialize streaming metrics tracking start_time = request_start_time if request_start_time is not None else time.perf_counter() @@ -472,12 +560,45 @@ class LLMNode(Node[LLMNodeData]): has_content = True full_text_buffer.write(text_part) + # Text output: always forward raw chunk (keep tags intact) yield StreamChunkEvent( selector=[node_id, "text"], chunk=text_part, is_final=False, ) + # Generation output: split out thoughts, forward only non-thought content chunks + for kind, segment in think_parser.process(text_part): + if not segment: + if kind not in {"thought_start", "thought_end"}: + continue + + if kind == "thought_start": + yield ThoughtStartChunkEvent( + selector=[node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + elif kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + elif kind == "thought_end": + yield ThoughtEndChunkEvent( + selector=[node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + else: + yield StreamChunkEvent( + selector=[node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + # Update the whole metadata if not model and result.model: model = result.model @@ -492,16 +613,47 @@ class LLMNode(Node[LLMNodeData]): except OutputParserError as e: raise LLMNodeError(f"Failed to parse structured output: {e}") + for kind, segment in think_parser.flush(): + if not segment and kind not in {"thought_start", "thought_end"}: + continue + if kind == "thought_start": + yield ThoughtStartChunkEvent( + selector=[node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + elif kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + elif kind == "thought_end": + yield ThoughtEndChunkEvent( + selector=[node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + else: + yield StreamChunkEvent( + selector=[node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + # Extract reasoning 
content from <think> tags in the main text
         full_text = full_text_buffer.getvalue()
 
         if reasoning_format == "tagged":
             # Keep <think> tags in text for backward compatibility
             clean_text = full_text
-            reasoning_content = ""
+            reasoning_content = "".join(reasoning_chunks)
         else:
             # Extract clean text and reasoning from <think> tags
             clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+            if reasoning_chunks and not reasoning_content:
+                reasoning_content = "".join(reasoning_chunks)
 
         # Calculate streaming metrics
         end_time = time.perf_counter()
@@ -525,7 +677,7 @@ class LLMNode(Node[LLMNodeData]):
         )
 
     @staticmethod
-    def _image_file_to_markdown(file: "File", /):
+    def _image_file_to_markdown(file: File, /):
         text_chunk = f"![]({file.generate_url()})"
         return text_chunk
 
@@ -774,7 +926,7 @@ class LLMNode(Node[LLMNodeData]):
     def fetch_prompt_messages(
         *,
         sys_query: str | None = None,
-        sys_files: Sequence["File"],
+        sys_files: Sequence[File],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -785,7 +937,7 @@
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
         tenant_id: str,
-        context_files: list["File"] | None = None,
+        context_files: list[File] | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
         prompt_messages: list[PromptMessage] = []
@@ -1137,7 +1289,7 @@
         *,
         invoke_result: LLMResult | LLMResultWithStructuredOutput,
         saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
         reasoning_format: Literal["separated", "tagged"] = "tagged",
         request_latency: float | None = None,
     ) -> ModelInvokeCompletedEvent:
@@ -1179,7 +1331,7 @@
         *,
         content: ImagePromptMessageContent,
         file_saver: LLMFileSaver,
-    ) -> "File":
+    ) -> File:
         """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
 
         There are two kinds of multimodal outputs:
@@ -1229,7 +1381,7 @@
         *,
         contents: str | list[PromptMessageContentUnionTypes] | None,
         file_saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
     ) -> Generator[str, None, None]:
         """Convert intermediate prompt messages into strings and yield them to the caller.
@@ -1266,6 +1418,732 @@ def retry(self) -> bool:
         return self.node_data.retry_config.retry_enabled
 
+    @property
+    def tool_call_enabled(self) -> bool:
+        return (
+            self.node_data.tools is not None
+            and len(self.node_data.tools) > 0
+            and all(tool.enabled for tool in self.node_data.tools)
+        )
+
+    def _stream_llm_events(
+        self,
+        generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
+        *,
+        model_instance: ModelInstance,
+    ) -> Generator[
+        NodeEventBase,
+        None,
+        tuple[
+            str,
+            str,
+            str,
+            str,
+            LLMUsage,
+            str | None,
+            LLMStructuredOutput | None,
+            LLMGenerationData | None,
+        ],
+    ]:
+        """
+        Stream events and capture generator return value in one place.
+        Uses generator delegation so _run stays concise while still emitting events.
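+
+        The events yielded here are re-emitted by _run via `yield from`, while
+        the returned tuple hands the aggregated results (clean text, reasoning,
+        usage, structured output, generation data) back to _run.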
+ """ + clean_text = "" + reasoning_content = "" + generation_reasoning_content = "" + generation_clean_content = "" + usage = LLMUsage.empty_usage() + finish_reason: str | None = None + structured_output: LLMStructuredOutput | None = None + generation_data: LLMGenerationData | None = None + completed = False + + while True: + try: + event = next(generator) + except StopIteration as exc: + if isinstance(exc.value, LLMGenerationData): + generation_data = exc.value + break + + if completed: + # After completion we still drain to reach StopIteration.value + continue + + match event: + case StreamChunkEvent() | ThoughtChunkEvent(): + yield event + + case ModelInvokeCompletedEvent( + text=text, + usage=usage_event, + finish_reason=finish_reason_event, + reasoning_content=reasoning_event, + structured_output=structured_raw, + ): + clean_text = text + usage = usage_event + finish_reason = finish_reason_event + reasoning_content = reasoning_event or "" + generation_reasoning_content = reasoning_content + generation_clean_content = clean_text + + if self.node_data.reasoning_format == "tagged": + # Keep tagged text for output; also extract reasoning for generation field + generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning( + clean_text, reasoning_format="separated" + ) + else: + clean_text, generation_reasoning_content = LLMNode._split_reasoning( + clean_text, self.node_data.reasoning_format + ) + generation_clean_content = clean_text + + structured_output = ( + LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None + ) + + llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + completed = True + + case LLMStructuredOutput(): + structured_output = event + + case _: + continue + + return ( + clean_text, + reasoning_content, + generation_reasoning_content, + generation_clean_content, + usage, + finish_reason, + structured_output, + generation_data, + ) + + def _invoke_llm_with_tools( + self, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + files: Sequence["File"], + variable_pool: VariablePool, + node_inputs: dict[str, Any], + process_data: dict[str, Any], + ) -> Generator[NodeEventBase, None, LLMGenerationData]: + """Invoke LLM with tools support (from Agent V2). 
+ + Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files + """ + # Get model features to determine strategy + model_features = self._get_model_features(model_instance) + + # Prepare tool instances + tool_instances = self._prepare_tool_instances(variable_pool) + + # Prepare prompt files (files that come from prompt variables, not vision files) + prompt_files = self._extract_prompt_files(variable_pool) + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=model_instance, + tools=tool_instances, + files=prompt_files, + max_iterations=self._node_data.max_iterations or 10, + context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + ) + + # Run strategy + outputs = strategy.run( + prompt_messages=list(prompt_messages), + model_parameters=self._node_data.model.completion_params, + stop=list(stop or []), + stream=True, + ) + + # Process outputs and return generation result + result = yield from self._process_tool_outputs(outputs) + return result + + def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: + """Get model schema to determine features.""" + try: + model_type_instance = model_instance.model_type_instance + model_schema = model_type_instance.get_model_schema( + model_instance.model, + model_instance.credentials, + ) + return model_schema.features if model_schema and model_schema.features else [] + except Exception: + logger.warning("Failed to get model schema, assuming no special features") + return [] + + def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]: + """Prepare tool instances from configuration.""" + tool_instances = [] + + if self._node_data.tools: + for tool in self._node_data.tools: + try: + # Process settings to extract the correct structure + processed_settings = {} + for key, value in tool.settings.items(): + if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict): + # Extract the nested value if it has the ToolInput structure + if "type" in value["value"] and "value" in value["value"]: + processed_settings[key] = value["value"] + else: + processed_settings[key] = value + else: + processed_settings[key] = value + + # Merge parameters with processed settings (similar to Agent Node logic) + merged_parameters = {**tool.parameters, **processed_settings} + + # Create AgentToolEntity from ToolMetadata + agent_tool = AgentToolEntity( + provider_id=tool.provider_name, + provider_type=tool.type, + tool_name=tool.tool_name, + tool_parameters=merged_parameters, + plugin_unique_identifier=tool.plugin_unique_identifier, + credential_id=tool.credential_id, + ) + + # Get tool runtime from ToolManager + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + app_id=self.app_id, + agent_tool=agent_tool, + invoke_from=self.invoke_from, + variable_pool=variable_pool, + ) + + # Apply custom description from extra field if available + if tool.extra.get("description") and tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + tool.extra.get("description") or tool_runtime.entity.description.llm + ) + + tool_instances.append(tool_runtime) + except Exception as e: + logger.warning("Failed to load tool %s: %s", tool, str(e)) + continue + + return tool_instances + + def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]: + """Extract files from prompt template variables.""" + 
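+        # Editor's illustration (assumed Dify template syntax): a prompt variable
+        # like {{#start.file#}} that resolves to a FileVariable contributes one
+        # File below, while an ArrayFileVariable contributes all of its files.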
from core.variables import ArrayFileVariable, FileVariable + + files: list[File] = [] + + # Extract variables from prompt template + if isinstance(self._node_data.prompt_template, list): + for message in self._node_data.prompt_template: + if message.text: + parser = VariableTemplateParser(message.text) + variable_selectors = parser.extract_variable_selectors() + + for variable_selector in variable_selectors: + variable = variable_pool.get(variable_selector.value_selector) + if isinstance(variable, FileVariable) and variable.value: + files.append(variable.value) + elif isinstance(variable, ArrayFileVariable) and variable.value: + files.extend(variable.value) + + return files + + @staticmethod + def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]: + """Convert ToolCallResult into JSON-friendly dict.""" + + def _file_to_ref(file: File) -> str | None: + # Align with streamed tool result events which carry file IDs + return file.id or file.related_id + + files = [] + for file in tool_call.files or []: + ref = _file_to_ref(file) + if ref: + files.append(ref) + + return { + "id": tool_call.id, + "name": tool_call.name, + "arguments": tool_call.arguments, + "output": tool_call.output, + "files": files, + "status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status, + "elapsed_time": tool_call.elapsed_time, + } + + def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None: + """Generate icon URL for model provider.""" + from yarl import URL + + from configs import dify_config + + icon_type = "icon_small_dark" if dark else "icon_small" + try: + return str( + URL(dify_config.CONSOLE_API_URL or "/") + / "console" + / "api" + / "workspaces" + / "current" + / "model-providers" + / provider + / icon_type + / "en_US" + ) + except Exception: + return None + + def _flush_model_segment( + self, + buffers: StreamBuffers, + trace_state: TraceState, + error: str | None = None, + ) -> None: + """Flush pending thought/content buffers into a single model trace segment.""" + if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls: + return + + now = time.perf_counter() + duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0 + + # Use pending_usage from trace_state (captured from THOUGHT log) + usage = trace_state.pending_usage + + # Generate model provider icon URL + provider = self._node_data.model.provider + model_name = self._node_data.model.name + model_icon = self._generate_model_provider_icon_url(provider) + model_icon_dark = self._generate_model_provider_icon_url(provider, dark=True) + + trace_state.trace_segments.append( + LLMTraceSegment( + type="model", + duration=duration, + usage=usage, + output=ModelTraceSegment( + text="".join(buffers.pending_content) if buffers.pending_content else None, + reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None, + tool_calls=list(buffers.pending_tool_calls), + ), + provider=provider, + name=model_name, + icon=model_icon, + icon_dark=model_icon_dark, + error=error, + status="error" if error else "success", + ) + ) + buffers.pending_thought.clear() + buffers.pending_content.clear() + buffers.pending_tool_calls.clear() + trace_state.model_segment_start_time = None + trace_state.pending_usage = None + + def _handle_agent_log_output( + self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext + ) -> Generator[NodeEventBase, None, 
None]: + payload = ToolLogPayload.from_log(output) + agent_log_event = AgentLogEvent( + message_id=output.id, + label=output.label, + node_execution_id=self.id, + parent_id=output.parent_id, + error=output.error, + status=output.status.value, + data=output.data, + metadata={k.value: v for k, v in output.metadata.items()}, + node_id=self._node_id, + ) + for log in agent_context.agent_logs: + if log.message_id == agent_log_event.message_id: + log.data = agent_log_event.data + log.status = agent_log_event.status + log.error = agent_log_event.error + log.label = agent_log_event.label + log.metadata = agent_log_event.metadata + break + else: + agent_context.agent_logs.append(agent_log_event) + + # Handle THOUGHT log completion - capture usage for model segment + if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS: + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None + if llm_usage: + trace_state.pending_usage = llm_usage + + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_call_id = payload.tool_call_id + tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" + + # Get icon from metadata (available at START) + tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None + tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None + + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) + + # Add tool call to pending list for model segment + buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments)) + + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk=tool_arguments, + tool_call=ToolCall( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + icon=tool_icon, + icon_dark=tool_icon_dark, + ), + is_final=False, + ) + + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_output = payload.tool_output + tool_call_id = payload.tool_call_id + tool_files = payload.files if isinstance(payload.files, list) else [] + tool_error = payload.tool_error + tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" + + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) + + # Flush model segment before tool result processing + self._flush_model_segment(buffers, trace_state) + + if output.status == AgentLog.LogStatus.ERROR: + tool_error = output.error or payload.tool_error + if not tool_error and payload.meta: + tool_error = payload.meta.get("error") + else: + if payload.meta: + meta_error = payload.meta.get("error") + if meta_error: + tool_error = meta_error + + elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None + tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None + tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None + tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None + result_str = str(tool_output) if tool_output is not None else None + + tool_status: Literal["success", "error"] = 
"error" if tool_error else "success" + tool_call_segment = LLMTraceSegment( + type="tool", + duration=elapsed_time or 0.0, + usage=None, + output=ToolTraceSegment( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + output=result_str, + ), + provider=tool_provider, + name=tool_name, + icon=tool_icon, + icon_dark=tool_icon_dark, + error=str(tool_error) if tool_error else None, + status=tool_status, + ) + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + trace_state.tool_trace_map[tool_call_id] = tool_call_segment + + # Start new model segment tracking + trace_state.model_segment_start_time = time.perf_counter() + + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk=result_str or "", + tool_result=ToolResult( + id=tool_call_id, + name=tool_name, + output=result_str, + files=tool_files, + status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, + elapsed_time=elapsed_time, + icon=tool_icon, + icon_dark=tool_icon_dark, + ), + is_final=False, + ) + + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) + buffers.current_turn_reasoning.clear() + + def _handle_llm_chunk_output( + self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + message = output.delta.message + + if message and message.content: + chunk_text = message.content + if isinstance(chunk_text, list): + chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text) + else: + chunk_text = str(chunk_text) + + for kind, segment in buffers.think_parser.process(chunk_text): + if not segment and kind not in {"thought_start", "thought_end"}: + continue + + # Start tracking model segment time on first output + if trace_state.model_segment_start_time is None: + trace_state.model_segment_start_time = time.perf_counter() + + if kind == "thought_start": + yield ThoughtStartChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + elif kind == "thought": + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + elif kind == "thought_end": + yield ThoughtEndChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + else: + aggregate.text += segment + buffers.pending_content.append(segment) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + if output.delta.usage: + self._accumulate_usage(aggregate.usage, output.delta.usage) + + if output.delta.finish_reason: + aggregate.finish_reason = output.delta.finish_reason + + def _flush_remaining_stream( + self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + for kind, segment in buffers.think_parser.flush(): + if not segment and kind not in {"thought_start", "thought_end"}: + continue + + # Start tracking model segment time on first output + if trace_state.model_segment_start_time is None: + trace_state.model_segment_start_time = time.perf_counter() + + if kind == "thought_start": + yield ThoughtStartChunkEvent( + selector=[self._node_id, "generation", 
"thought"], + chunk="", + is_final=False, + ) + elif kind == "thought": + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + elif kind == "thought_end": + yield ThoughtEndChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=False, + ) + else: + aggregate.text += segment + buffers.pending_content.append(segment) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) + + # For final flush, use aggregate.usage if pending_usage is not set + # (e.g., for simple LLM calls without tool invocations) + if trace_state.pending_usage is None: + trace_state.pending_usage = aggregate.usage + + # Flush final model segment + self._flush_model_segment(buffers, trace_state) + + def _close_streams(self) -> Generator[NodeEventBase, None, None]: + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=True, + ) + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk="", + tool_call=ToolCall( + id="", + name="", + arguments="", + ), + is_final=True, + ) + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk="", + tool_result=ToolResult( + id="", + name="", + output="", + files=[], + status=ToolResultStatus.SUCCESS, + ), + is_final=True, + ) + + def _build_generation_data( + self, + trace_state: TraceState, + agent_context: AgentContext, + aggregate: AggregatedResult, + buffers: StreamBuffers, + ) -> LLMGenerationData: + sequence: list[dict[str, Any]] = [] + reasoning_index = 0 + content_position = 0 + tool_call_seen_index: dict[str, int] = {} + for trace_segment in trace_state.trace_segments: + if trace_segment.type == "thought": + sequence.append({"type": "reasoning", "index": reasoning_index}) + reasoning_index += 1 + elif trace_segment.type == "content": + segment_text = trace_segment.text or "" + start = content_position + end = start + len(segment_text) + sequence.append({"type": "content", "start": start, "end": end}) + content_position = end + elif trace_segment.type == "tool_call": + tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else "" + if tool_id not in tool_call_seen_index: + tool_call_seen_index[tool_id] = len(tool_call_seen_index) + sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) + + tool_calls_for_generation: list[ToolCallResult] = [] + for log in agent_context.agent_logs: + payload = ToolLogPayload.from_mapping(log.data or {}) + tool_call_id = payload.tool_call_id + if not tool_call_id or log.status == AgentLog.LogStatus.START.value: + continue + + tool_args = payload.tool_args + log_error = payload.tool_error + log_output = payload.tool_output + result_text = log_output or log_error or "" + status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS + tool_calls_for_generation.append( + ToolCallResult( + 
id=tool_call_id,
+                    name=payload.tool_name,
+                    arguments=json.dumps(tool_args) if tool_args else "",
+                    output=result_text,
+                    status=status,
+                    elapsed_time=log.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if log.metadata else None,
+                )
+            )
+
+        tool_calls_for_generation.sort(
+            key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
+        )
+
+        return LLMGenerationData(
+            text=aggregate.text,
+            reasoning_contents=buffers.reasoning_per_turn,
+            tool_calls=tool_calls_for_generation,
+            sequence=sequence,
+            usage=aggregate.usage,
+            finish_reason=aggregate.finish_reason,
+            files=aggregate.files,
+            trace=trace_state.trace_segments,
+        )
+
+    def _process_tool_outputs(
+        self,
+        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
+    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
+        """Process strategy outputs and convert to node events."""
+        state = ToolOutputState()
+
+        try:
+            # Drain with next() so the generator's AgentResult return value is
+            # observable here; a plain for loop would swallow the StopIteration
+            # that carries it, and the except block below would never run.
+            while True:
+                output = next(outputs)
+                if isinstance(output, AgentLog):
+                    yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
+                else:
+                    yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
+        except StopIteration as exception:
+            if isinstance(getattr(exception, "value", None), AgentResult):
+                state.agent.agent_result = exception.value
+
+        if state.agent.agent_result:
+            state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
+            state.aggregate.files = state.agent.agent_result.files
+            if state.agent.agent_result.usage:
+                state.aggregate.usage = state.agent.agent_result.usage
+            if state.agent.agent_result.finish_reason:
+                state.aggregate.finish_reason = state.agent.agent_result.finish_reason
+
+        yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
+        yield from self._close_streams()
+
+        return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
+
+    def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
+        """Accumulate LLM usage statistics."""
+        total_usage.prompt_tokens += delta_usage.prompt_tokens
+        total_usage.completion_tokens += delta_usage.completion_tokens
+        total_usage.total_tokens += delta_usage.total_tokens
+        total_usage.prompt_price += delta_usage.prompt_price
+        total_usage.completion_price += delta_usage.completion_price
+        total_usage.total_price += delta_usage.total_price
+
 
 def _combine_message_content_with_role(
     *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py
index f177aef665..557d3a330d 100644
--- a/api/core/workflow/nodes/node_factory.py
+++ b/api/core/workflow/nodes/node_factory.py
@@ -113,7 +113,6 @@ class DifyNodeFactory(NodeFactory):
                 code_providers=self._code_providers,
                 code_limits=self._code_limits,
             )
-
         if node_type == NodeType.TEMPLATE_TRANSFORM:
             return TemplateTransformNode(
                 id=node_id,
diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py
deleted file mode 100644
index 050e213535..0000000000
--- a/api/core/workflow/nodes/variable_assigner/common/impl.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.variables.variables import Variable
-from extensions.ext_database import db
-from models import ConversationVariable
-
-from .exc import VariableOperatorNodeError
-
-
-class 
ConversationVariableUpdaterImpl: - def update(self, conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() - - def flush(self): - pass - - -def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: - return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index da23207b62..d2ea7d94ea 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,9 +1,8 @@ -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, TypeAlias +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: from core.workflow.runtime import GraphRuntimeState -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - - class VariableAssignerNode(Node[VariableAssignerData]): node_type = NodeType.VARIABLE_ASSIGNER - _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY def __init__( self, @@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( id=id, @@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._conv_var_updater_factory = conv_var_updater_factory @classmethod def version(cls) -> str: @@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - # TODO: Move database operation to the pipeline. - # Update conversation variable. 
- conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - raise VariableOperatorNodeError("conversation_id not found") - conv_var_updater = self._conv_var_updater_factory() - conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) - conv_var_updater.flush() updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 389fb54d35..486e6bb6a7 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,24 +1,20 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( - ConversationIDNotFoundError, InputTypeNotSupportedError, InvalidDataError, InvalidInputValueError, @@ -26,6 +22,10 @@ from .exc import ( VariableNotFoundError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_ASSIGNER + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. @@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): return False - def _conv_var_updater_factory(self) -> ConversationVariableUpdater: - return conversation_variable_updater_factory() - @classmethod def version(cls) -> str: return "2" @@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # remove the duplicated items first. 
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - conv_var_updater = self._conv_var_updater_factory() - # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) if not isinstance(variable, Variable): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value - if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - if self.invoke_from != InvokeFrom.DEBUGGER: - raise ConversationIDNotFoundError - else: - conversation_id = conversation_id.value - conv_var_updater.update( - conversation_id=cast(str, conversation_id), - variable=variable, - ) - conv_var_updater.flush() updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py index 97bfcd5666..66ef714c16 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/core/workflow/repositories/draft_variable_repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from collections.abc import Mapping from typing import Any, Protocol @@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol): node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, - ) -> "DraftVariableSaver": + ) -> DraftVariableSaver: pass diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 1561b789df..401cecc162 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ from __future__ import annotations import importlib import json +import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -168,6 +169,7 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() + self.stop_event: threading.Event = threading.Event() if graph is not None: self.attach_graph(graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index 5e0878e873..bfbb5ba704 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage @@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... 
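The `ReadOnlyVariablePool.get` protocol above now takes a single positional-only selector sequence instead of the old `(node_id, variable_key)` pair, matching `VariablePool.get`. A minimal sketch of a caller adapting to the new signature (the `pool` object and helper name are illustrative, not part of this change):

```
# Assumes `pool` satisfies the updated ReadOnlyVariablePool protocol.
def read_conversation_id(pool) -> str | None:
    # Old (removed) form: pool.get("sys", "conversation_id")
    segment = pool.get(["sys", "conversation_id"])  # one selector sequence, passed positionally
    return segment.text if segment is not None else None
```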
diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index 8539727fd6..d3e4c60d9b 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any @@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper: def __init__(self, variable_pool: VariablePool) -> None: self._variable_pool = variable_pool - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Return a copy of a variable value if present.""" - value = self._variable_pool.get([node_id, variable_key]) + value = self._variable_pool.get(selector) return deepcopy(value) if value is not None else None def get_all_by_node(self, node_id: str) -> Mapping[str, object]: diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 7fbaec9e70..85ceb9d59e 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import defaultdict from collections.abc import Mapping, Sequence @@ -267,6 +269,6 @@ class VariablePool(BaseModel): self.add(selector, value) @classmethod - def empty(cls) -> "VariablePool": + def empty(cls) -> VariablePool: """Create an empty variable pool.""" return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index ad925912a4..cda8091771 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from types import MappingProxyType from typing import Any @@ -70,7 +72,7 @@ class SystemVariable(BaseModel): return data @classmethod - def empty(cls) -> "SystemVariable": + def empty(cls) -> SystemVariable: return cls() def to_dict(self) -> dict[SystemVariableKey, Any]: @@ -114,7 +116,7 @@ class SystemVariable(BaseModel): d[SystemVariableKey.TIMESTAMP] = self.timestamp return d - def as_view(self) -> "SystemVariableReadOnlyView": + def as_view(self) -> SystemVariableReadOnlyView: return SystemVariableReadOnlyView(self) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 5a69eb15ac..c0279f893b 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -3,8 +3,9 @@ set -e # Set UTF-8 encoding to address potential encoding issues in containerized environments -export LANG=${LANG:-en_US.UTF-8} -export LC_ALL=${LC_ALL:-en_US.UTF-8} +# Use C.UTF-8 which is universally available in all containers +export LANG=${LANG:-C.UTF-8} +export LC_ALL=${LC_ALL:-C.UTF-8} export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8} if [[ "${MIGRATION_ENABLED}" == "true" ]]; then diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index cf994c11df..7d13f0c061 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] 
= (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) +EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE) EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") @@ -42,10 +43,28 @@ def init_app(app: DifyApp): _apply_cors_once( web_bp, - resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=list(AUTHENTICATED_HEADERS), - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + resources={ + # Embedded bot endpoints (unauthenticated, cross-origin safe) + r"^/chat-messages$": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + r"^/chat-messages/.*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + # Default web application endpoints (authenticated) + r"/*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": True, + "allow_headers": list(AUTHENTICATED_HEADERS), + "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + }, + }, expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(web_bp) diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 22d1f473a3..8c64a25be4 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import threading @@ -33,7 +35,7 @@ class AliyunLogStore: Ensures only one instance exists to prevent multiple PG connection pools. """ - _instance: "AliyunLogStore | None" = None + _instance: AliyunLogStore | None = None _initialized: bool = False # Track delayed PG connection for newly created projects @@ -66,7 +68,7 @@ class AliyunLogStore: "\t", ] - def __new__(cls) -> "AliyunLogStore": + def __new__(cls) -> AliyunLogStore: """Implement singleton pattern.""" if cls._instance is None: cls._instance = super().__new__(cls) diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 51a97b20f8..1d9911465b 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -5,6 +5,8 @@ automatic cleanup, backup and restore. Supports complete lifecycle management for knowledge base files. 
""" +from __future__ import annotations + import json import logging import operator @@ -48,7 +50,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> "FileMetadata": + def from_dict(cls, data: dict) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index d8ae0ad8b8..fe59cdcbb4 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -169,6 +169,7 @@ class MessageDetail(ResponseModel): status: str error: str | None = None parent_message_id: str | None = None + generation_detail: JSONValue | None = Field(default=None, validation_alias="generation_detail_dict") @field_validator("inputs", mode="before") @classmethod diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 70138404c7..913fb675f9 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,93 +1,85 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -upload_config_fields = { - "file_size_limit": fields.Integer, - "batch_count_limit": fields.Integer, - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, - "image_file_batch_limit": fields.Integer, - "single_chunk_attachment_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, field_validator -def build_upload_config_model(api_or_ns: Namespace): - """Build the upload config model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("UploadConfig", upload_config_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -file_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "preview_url": fields.String, - "source_url": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def build_file_model(api_or_ns: Namespace): - """Build the file model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("File", file_fields) +class UploadConfig(ResponseModel): + file_size_limit: int + batch_count_limit: int + file_upload_limit: int | None = None + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + workflow_file_upload_limit: int + image_file_batch_limit: int + single_chunk_attachment_limit: int + attachment_image_file_size_limit: int | None = None -remote_file_info_fields = { - "file_type": fields.String(attribute="file_type"), - "file_length": fields.Integer(attribute="file_length"), -} +class FileResponse(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + preview_url: str | None = None + source_url: str | None = None + original_url: str | None = None + user_id: str | None = None + tenant_id: str | None = None + conversation_id: str | None = None + file_key: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_remote_file_info_model(api_or_ns: Namespace): - """Build the remote file info model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) +class RemoteFileInfo(ResponseModel): + file_type: str + file_length: int -file_fields_with_signed_url = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "url": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, -} +class FileWithSignedUrl(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + url: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_file_with_signed_url_model(api_or_ns: Namespace): - """Build the file with signed URL model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) +__all__ = [ + "FileResponse", + "FileWithSignedUrl", + "RemoteFileInfo", + "UploadConfig", +] diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 2bba198fa8..797f01c00c 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -59,6 +59,7 @@ class MessageListItem(ResponseModel): message_files: list[MessageFile] status: str error: str | None = None + generation_detail: JSONValueType | None = Field(default=None, validation_alias="generation_detail_dict") @field_validator("inputs", mode="before") @classmethod diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index f9e858c68b..97c02e7085 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields # type: ignore +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 476025064f..1b2948811b 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -81,6 +81,7 @@ workflow_run_detail_fields = { "inputs": fields.Raw(attribute="inputs_dict"), "status": fields.String, "outputs": fields.Raw(attribute="outputs_dict"), + "outputs_as_generation": fields.Boolean, "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, @@ -129,6 +130,7 @@ workflow_run_node_execution_fields = { "inputs_truncated": fields.Boolean, "outputs_truncated": fields.Boolean, "process_data_truncated": fields.Boolean, + "generation_detail": fields.Raw, } workflow_run_node_execution_list_fields = { diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index 5bbf0c79a3..d4cb3e9971 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -2,6 +2,8 @@ Broadcast channel for Pub/Sub messaging. """ +from __future__ import annotations + import types from abc import abstractmethod from collections.abc import Iterator @@ -129,6 +131,6 @@ class BroadcastChannel(Protocol): """ @abstractmethod - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: """topic returns a `Topic` instance for the given topic name.""" ... 
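Several files in this patch pair `from __future__ import annotations` with unquoting forward references such as `-> "Topic"`. A self-contained sketch of why the unquoted annotation works once the future import makes all annotations lazy (class names here are placeholders):

```
from __future__ import annotations


class Channel:
    def topic(self, name: str) -> Topic:  # unquoted forward reference; evaluated lazily
        return Topic(name)


class Topic:
    def __init__(self, name: str) -> None:
        self.name = name
```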
diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py
index 1fc3db8156..5bb4f579c1 100644
--- a/api/libs/broadcast_channel/redis/channel.py
+++ b/api/libs/broadcast_channel/redis/channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis
 
@@ -20,7 +22,7 @@ class BroadcastChannel:
     ):
         self._client = redis_client
 
-    def topic(self, topic: str) -> "Topic":
+    def topic(self, topic: str) -> Topic:
         return Topic(self._client, topic)
 
diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py
index 16e3a80ee1..d190c51bbc 100644
--- a/api/libs/broadcast_channel/redis/sharded_channel.py
+++ b/api/libs/broadcast_channel/redis/sharded_channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis
 
@@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
     ):
         self._client = redis_client
 
-    def topic(self, topic: str) -> "ShardedTopic":
+    def topic(self, topic: str) -> ShardedTopic:
         return ShardedTopic(self._client, topic)
 
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
index ff74ccbe8e..0828cf80bf 100644
--- a/api/libs/email_i18n.py
+++ b/api/libs/email_i18n.py
@@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper
 type hints and eliminates the need for repetitive language switching logic.
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
 from enum import StrEnum, auto
 from typing import Any, Protocol
@@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
     ZH_HANS = "zh-Hans"
 
     @classmethod
-    def from_language_code(cls, language_code: str) -> "EmailLanguage":
+    def from_language_code(cls, language_code: str) -> EmailLanguage:
         """Convert a language code to EmailLanguage with fallback to English."""
         if language_code == "zh-Hans":
             return cls.ZH_HANS
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 74e1808e49..07c4823727 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -32,6 +32,38 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+
+def escape_like_pattern(pattern: str) -> str:
+    r"""
+    Escape special characters in a string for safe use in SQL LIKE patterns.
+
+    This function escapes the special characters used in SQL LIKE patterns:
+    - Backslash (\) -> \\
+    - Percent (%) -> \%
+    - Underscore (_) -> \_
+
+    The escaped pattern can then be safely used in SQL LIKE queries with the
+    ESCAPE '\' clause to prevent SQL injection via LIKE wildcards.
+
+    Args:
+        pattern: The string pattern to escape
+
+    Returns:
+        Escaped string safe for use in SQL LIKE queries
+
+    Examples:
+        >>> escape_like_pattern("50% discount")
+        '50\\% discount'
+        >>> escape_like_pattern("test_data")
+        'test\\_data'
+        >>> escape_like_pattern("path\\to\\file")
+        'path\\\\to\\\\file'
+    """
+    if not pattern:
+        return pattern
+    # Escape backslash first, then percent and underscore
+    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+
+
 def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
     """
     Extract tenant_id from Account or EndUser object. 
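A usage sketch for the `escape_like_pattern` helper added above, paired with SQLAlchemy's `like(..., escape="\\")` so the SQL ESCAPE clause matches the helper's escaping convention; the query function is invented for illustration, using the repo's `App` model:

```
from sqlalchemy import select

from libs.helper import escape_like_pattern
from models import App


def search_apps_by_name(session, keyword: str):
    # "50%_off" becomes "50\%\_off", so % and _ are matched literally.
    escaped = escape_like_pattern(keyword)
    stmt = select(App).where(App.name.like(f"%{escaped}%", escape="\\"))
    return session.scalars(stmt).all()
```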
diff --git a/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py b/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py
new file mode 100644
index 0000000000..60786a720c
--- /dev/null
+++ b/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py
@@ -0,0 +1,46 @@
+"""add llm generation detail table.
+
+Revision ID: 85c8b4a64f53
+Revises: 03ea244985ce
+Create Date: 2025-12-10 16:17:46.597669
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '85c8b4a64f53'
+down_revision = '03ea244985ce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('llm_generation_details',
+    sa.Column('id', models.types.StringUUID(), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('app_id', models.types.StringUUID(), nullable=False),
+    sa.Column('message_id', models.types.StringUUID(), nullable=True),
+    sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
+    sa.Column('node_id', sa.String(length=255), nullable=True),
+    sa.Column('reasoning_content', models.types.LongText(), nullable=True),
+    sa.Column('tool_calls', models.types.LongText(), nullable=True),
+    sa.Column('sequence', models.types.LongText(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.CheckConstraint('(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)', name=op.f('llm_generation_details_ck_llm_generation_detail_assoc_mode_check')),
+    sa.PrimaryKeyConstraint('id', name='llm_generation_detail_pkey'),
+    sa.UniqueConstraint('message_id', name=op.f('llm_generation_details_message_id_key'))
+    )
+    with op.batch_alter_table('llm_generation_details', schema=None) as batch_op:
+        batch_op.create_index('idx_llm_generation_detail_message', ['message_id'], unique=False)
+        batch_op.create_index('idx_llm_generation_detail_workflow', ['workflow_run_id', 'node_id'], unique=False)
+
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('llm_generation_details') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 87dbdd6c70..6b562b676e 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -49,6 +49,7 @@ from .model import ( EndUser, IconType, InstalledApp, + LLMGenerationDetail, Message, MessageAgentThought, MessageAnnotation, @@ -155,6 +156,7 @@ __all__ = [ "IconType", "InstalledApp", "InvitationCode", + "LLMGenerationDetail", "LoadBalancingModelConfig", "Message", "MessageAgentThought", diff --git a/api/models/model.py b/api/models/model.py index 6cfcc2859d..617e70b992 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re import uuid @@ -5,7 +7,7 @@ from collections.abc import Mapping from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from uuid import uuid4 import sqlalchemy as sa @@ -31,6 +33,8 @@ from .provider_ids import GenericProviderID from .types import LongText, StringUUID if TYPE_CHECKING: + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + from .workflow import Workflow @@ -54,7 +58,7 @@ class AppMode(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "AppMode": + def value_of(cls, value: str) -> AppMode: """ Get value of given mode. @@ -70,6 +74,7 @@ class AppMode(StrEnum): class IconType(StrEnum): IMAGE = auto() EMOJI = auto() + LINK = auto() class App(Base): @@ -81,7 +86,7 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) @@ -120,19 +125,19 @@ class App(Base): return "" @property - def site(self) -> Optional["Site"]: + def site(self) -> Site | None: site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional["Workflow"]: + def workflow(self) -> Workflow | None: if self.workflow_id: from .workflow import Workflow @@ -287,7 +292,7 @@ class App(Base): return deleted_tools @property - def tags(self) -> list["Tag"]: + def tags(self) -> list[Tag]: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) @@ -1193,7 +1198,7 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list["MessageAgentThought"]: + def agent_thoughts(self) -> list[MessageAgentThought]: return ( db.session.query(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) @@ -1201,6 +1206,18 @@ class Message(Base): .all() ) + # FIXME (Novice) -- It's easy to cause N+1 query problem here. 
+ @property + def generation_detail(self) -> dict[str, Any] | None: + """ + Get LLM generation detail for this message. + Returns the detail as a dictionary or None if not found. + """ + detail = db.session.query(LLMGenerationDetail).filter_by(message_id=self.id).first() + if detail: + return detail.to_dict() + return None + @property def retriever_resources(self) -> Any: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @@ -1306,7 +1323,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Message": + def from_dict(cls, data: dict[str, Any]) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1533,7 +1550,7 @@ class OperationLog(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content: Mapped[Any] = mapped_column(sa.JSON) + content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -2070,3 +2087,87 @@ class TraceAppConfig(TypeBase): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class LLMGenerationDetail(Base): + """ + Store LLM generation details including reasoning process and tool calls. + + Association (choose one): + - For apps with Message: use message_id (one-to-one) + - For Workflow: use workflow_run_id + node_id (one run may have multiple LLM nodes) + """ + + __tablename__ = "llm_generation_details" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"), + sa.Index("idx_llm_generation_detail_message", "message_id"), + sa.Index("idx_llm_generation_detail_workflow", "workflow_run_id", "node_id"), + sa.CheckConstraint( + "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL)" + " OR " + "(message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)", + name="ck_llm_generation_detail_assoc_mode", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Association fields (choose one) + message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, unique=True) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + node_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # Core data as JSON strings + reasoning_content: Mapped[str | None] = mapped_column(LongText) + tool_calls: Mapped[str | None] = mapped_column(LongText) + sequence: Mapped[str | None] = mapped_column(LongText) + + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + def to_domain_model(self) -> "LLMGenerationDetailData": + """Convert to Pydantic domain model with proper validation.""" + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + + return LLMGenerationDetailData( + reasoning_content=json.loads(self.reasoning_content) if self.reasoning_content else [], + tool_calls=json.loads(self.tool_calls) if self.tool_calls else [], + sequence=json.loads(self.sequence) if self.sequence else [], + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to 
dictionary for API response.""" + return self.to_domain_model().to_response_dict() + + @classmethod + def from_domain_model( + cls, + data: "LLMGenerationDetailData", + *, + tenant_id: str, + app_id: str, + message_id: str | None = None, + workflow_run_id: str | None = None, + node_id: str | None = None, + ) -> "LLMGenerationDetail": + """Create from Pydantic domain model.""" + # Enforce association mode at object creation time as well. + message_mode = message_id is not None + workflow_mode = workflow_run_id is not None or node_id is not None + if message_mode and workflow_mode: + raise ValueError("LLMGenerationDetail cannot set both message_id and workflow_run_id/node_id.") + if not message_mode and not (workflow_run_id and node_id): + raise ValueError("LLMGenerationDetail requires either message_id or workflow_run_id+node_id.") + + return cls( + tenant_id=tenant_id, + app_id=app_id, + message_id=message_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + reasoning_content=json.dumps(data.reasoning_content) if data.reasoning_content else None, + tool_calls=json.dumps([tc.model_dump() for tc in data.tool_calls]) if data.tool_calls else None, + sequence=json.dumps([seg.model_dump() for seg in data.sequence]) if data.sequence else None, + ) diff --git a/api/models/provider.py b/api/models/provider.py index 2afd8c5329..441b54c797 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from enum import StrEnum, auto from functools import cached_property @@ -19,7 +21,7 @@ class ProviderType(StrEnum): SYSTEM = auto() @staticmethod - def value_of(value: str) -> "ProviderType": + def value_of(value: str) -> ProviderType: for member in ProviderType: if member.value == value: return member @@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum): """hosted trial quota""" @staticmethod - def value_of(value: str) -> "ProviderQuotaType": + def value_of(value: str) -> ProviderQuotaType: for member in ProviderQuotaType: if member.value == value: return member @@ -76,7 +78,7 @@ class Provider(TypeBase): quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="") quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None) - quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e4f9bcb582..e7b98dcf27 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from decimal import Decimal @@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase): ) @property - def schema_type(self) -> "ApiProviderSchemaType": + def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def tools(self) -> list["ApiToolBundle"]: + def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property @@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: + def 
parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: return [ WorkflowToolParameterConfiguration.model_validate(config) for config in json.loads(self.parameter_configuration) @@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase): except (json.JSONDecodeError, TypeError): return [] - def to_entity(self) -> "MCPProviderEntity": + def to_entity(self) -> MCPProviderEntity: """Convert to domain entity""" from core.entities.mcp_provider import MCPProviderEntity @@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase): ) @property - def description_i18n(self) -> "I18nObject": + def description_i18n(self) -> I18nObject: return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/trigger.py b/api/models/trigger.py index 87e2a5ccfc..209345eb84 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -415,7 +415,7 @@ class AppTrigger(TypeBase): node_id: Mapped[str | None] = mapped_column(String(64), nullable=False) trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable? + provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="") status: Mapped[str] = mapped_column( EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 853d5afefc..7e8a0f7c2e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -57,6 +59,37 @@ from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +def is_generation_outputs(outputs: Mapping[str, Any]) -> bool: + if not outputs: + return False + + allowed_sequence_types = {"reasoning", "content", "tool_call"} + + def valid_sequence_item(item: Mapping[str, Any]) -> bool: + return isinstance(item, Mapping) and item.get("type") in allowed_sequence_types + + def valid_value(value: Any) -> bool: + if not isinstance(value, Mapping): + return False + + content = value.get("content") + reasoning_content = value.get("reasoning_content") + tool_calls = value.get("tool_calls") + sequence = value.get("sequence") + + return ( + isinstance(content, str) + and isinstance(reasoning_content, list) + and all(isinstance(item, str) for item in reasoning_content) + and isinstance(tool_calls, list) + and all(isinstance(item, Mapping) for item in tool_calls) + and isinstance(sequence, list) + and all(valid_sequence_item(item) for item in sequence) + ) + + return all(valid_value(value) for value in outputs.values()) + + class WorkflowType(StrEnum): """ Workflow Type Enum @@ -67,7 +100,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "WorkflowType": + def value_of(cls, value: str) -> WorkflowType: """ Get value of given mode. 
@@ -80,7 +113,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: """ Get workflow type from app mode. @@ -181,7 +214,7 @@ class Workflow(Base): # bug rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> "Workflow": + ) -> Workflow: workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -619,7 +652,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( + pause: Mapped[WorkflowPause | None] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -664,6 +697,10 @@ class WorkflowRun(Base): def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + @property + def outputs_as_generation(self): + return is_generation_outputs(self.outputs_dict) + def to_dict(self): return { "id": self.id, @@ -677,6 +714,7 @@ class WorkflowRun(Base): "inputs": self.inputs_dict, "status": self.status, "outputs": self.outputs_dict, + "outputs_as_generation": self.outputs_as_generation, "error": self.error, "elapsed_time": self.elapsed_time, "total_tokens": self.total_tokens, @@ -689,7 +727,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": + def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -841,7 +879,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( + offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -851,13 +889,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( - query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -932,7 +970,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1046,7 +1084,7 @@ class WorkflowNodeExecutionOffload(Base): 
back_populates="offload_data", ) - file: Mapped[Optional["UploadFile"]] = orm.relationship( + file: Mapped[UploadFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1064,7 +1102,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": + def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: """ Get value of given mode. @@ -1181,7 +1219,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: obj = cls( id=variable.id, app_id=app_id, @@ -1334,7 +1372,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( + variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1504,8 +1542,9 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = WorkflowDraftVariable() + variable.id = str(uuid4()) variable.created_at = naive_utc_now() variable.updated_at = naive_utc_now() variable.description = description @@ -1526,7 +1565,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1547,7 +1586,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1570,7 +1609,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=node_id, @@ -1666,7 +1705,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped["UploadFile"] = orm.relationship( + upload_file: Mapped[UploadFile] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1733,7 +1772,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped["WorkflowRun"] = orm.relationship( + workflow_run: Mapped[WorkflowRun] = orm.relationship( foreign_keys=[workflow_run_id], # require explicit preloading. 
lazy="raise", @@ -1789,7 +1828,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 7f44fe05a6..b73302508a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -137,13 +137,16 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index deba0b79e8..acd2a25a86 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig +from models.model import AppModelConfig, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link", "image"]: + if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]: icon_type = icon_type_value else: - icon_type = "emoji" + icon_type = IconType.EMOJI.value icon = icon or str(app_data.get("icon", "")) if app: diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..02ebfbace0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..295d48d8a1 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.variables.types import SegmentType -from 
core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message +from services.conversation_variable_updater import ConversationVariableUpdater from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( @@ -335,7 +337,7 @@ class ConversationService: updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) # Use the conversation variable updater to persist the changes - updater = conversation_variable_updater_factory() + updater = ConversationVariableUpdater(session_factory.get_session_maker()) updater.update(conversation_id, updated_variable) updater.flush() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py new file mode 100644 index 0000000000..acc0ec2b22 --- /dev/null +++ b/api/services/conversation_variable_updater.py @@ -0,0 +1,28 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from core.variables.variables import Variable +from models import ConversationVariable + + +class ConversationVariableNotFoundError(Exception): + pass + + +class ConversationVariableUpdater: + def __init__(self, session_maker: sessionmaker[Session]) -> None: + self._session_maker: sessionmaker[Session] = session_maker + + def update(self, conversation_id: str, variable: Variable) -> None: + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with self._session_maker() as session: + row = session.scalar(stmt) + if not row: + raise ConversationVariableNotFoundError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self) -> None: + pass diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..18e5613438 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -144,7 +144,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: @@ -3423,7 +3424,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = 
query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3458,8 @@ class SegmentService: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/llm_generation_service.py b/api/services/llm_generation_service.py new file mode 100644 index 0000000000..eb8327537e --- /dev/null +++ b/api/services/llm_generation_service.py @@ -0,0 +1,37 @@ +""" +LLM Generation Detail Service. + +Provides methods to query and attach generation details to workflow node executions +and messages, avoiding N+1 query problems. 
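+
+Illustrative usage (a sketch only, not part of the public contract; ``session``
+is an open SQLAlchemy Session supplied by the caller, and the message IDs are
+placeholders):
+
+    service = LLMGenerationService(session)
+    detail = service.get_generation_detail_for_message(message_id)
+    details_by_id = service.get_generation_details_for_messages(message_ids)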
+""" + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.llm_generation_entities import LLMGenerationDetailData +from models import LLMGenerationDetail + + +class LLMGenerationService: + """Service for handling LLM generation details.""" + + def __init__(self, session: Session): + self._session = session + + def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None: + """Query generation detail for a specific message.""" + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id) + detail = self._session.scalars(stmt).first() + return detail.to_domain_model() if detail else None + + def get_generation_details_for_messages( + self, + message_ids: list[str], + ) -> dict[str, LLMGenerationDetailData]: + """Batch query generation details for multiple messages.""" + if not message_ids: + return {} + + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids)) + details = self._session.scalars(stmt).all() + return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id} diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f53448e7fe..1ba64813ba 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -874,7 +874,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED: + if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() @@ -1318,7 +1318,7 @@ class RagPipelineService: "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)], "original_document_id": document.id, }, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, streaming=False, call_depth=0, workflow_thread_pool_id=None, diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 0f969207cf..f973361341 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from abc import ABC, abstractmethod from collections.abc import Mapping @@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator): self._max_size_bytes = max_size_bytes @classmethod - def default(cls) -> "VariableTruncator": + def default(cls) -> VariableTruncator: return VariableTruncator( max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE, 
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH, diff --git a/api/services/website_service.py b/api/services/website_service.py index a23f01ec71..fe48c3b08e 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import json from dataclasses import dataclass @@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..8574d30255 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f299ce3baa..9407a2b3f0 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -679,6 +679,7 @@ def _batch_upsert_draft_variable( def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { + "id": model.id, "app_id": model.app_id, "last_edited_at": None, "node_id": model.node_id, diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 1eef361a92..3c5e152520 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py 
b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 275f5abe6e..093342d1a3 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index da73122cd7..5555400ca6 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -444,6 +444,78 @@ class TestAnnotationService: assert total == 1 assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with special characters in content + annotation_with_percent = { + "question": "Question with 50% discount", + "answer": "Answer about 50% discount offer", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id) + + annotation_with_underscore = { + "question": "Question with test_data", + "answer": "Answer about test_data value", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id) + + annotation_with_backslash = { + "question": "Question with path\\to\\file", + "answer": "Answer about path\\to\\file location", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id) + + # Create annotation that should NOT match (contains % but as part of different text) + annotation_no_match = { + "question": "Question with 100% different", + "answer": "Answer about 100% different content", + } + AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id) + + # Test 1: Search with % character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content + + # Test 2: Search with _ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="test_data" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content + + # Test 3: Search with \ character - should find exact match only + 
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="path\\to\\file" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + # Should only find the 50% annotation, not the 100% one + assert total == 1 + assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) + def test_get_annotation_list_by_app_id_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index e53392bcef..745d6c97b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -7,7 +7,9 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService class TestAppService: @@ -71,6 +73,9 @@ class TestAppService: } # Create app + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -109,6 +114,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Test different app modes @@ -159,6 +167,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) @@ -194,6 +205,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create multiple apps @@ -245,6 +259,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create apps with different modes @@ -315,6 +332,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create an app @@ -392,6 +412,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -458,6 +481,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # 
Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -508,6 +534,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -562,6 +591,9 @@ class TestAppService: "icon_background": "#74B9FF", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -617,6 +649,9 @@ class TestAppService: "icon_background": "#A29BFE", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -672,6 +707,9 @@ class TestAppService: "icon_background": "#FD79A8", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": "#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 6732b8d558..e8c7f17e0b 100644 --- 
a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import create_autospec, patch import pytest @@ -312,6 +313,85 @@ class TestTagService: result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 + def test_get_tags_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test tag retrieval with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + from extensions.ext_database import db + + # Create tags with special characters in names + tag_with_percent = Tag( + name="50% discount", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_percent.id = str(uuid.uuid4()) + db.session.add(tag_with_percent) + + tag_with_underscore = Tag( + name="test_data_tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_underscore.id = str(uuid.uuid4()) + db.session.add(tag_with_underscore) + + tag_with_backslash = Tag( + name="path\\to\\tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_backslash.id = str(uuid.uuid4()) + db.session.add(tag_with_backslash) + + # Create tag that should NOT match + tag_no_match = Tag( + name="100% different", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_no_match.id = str(uuid.uuid4()) + db.session.add(tag_no_match) + + db.session.commit() + + # Act & Assert: Test 1 - Search with % character + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert result[0].name == "50% discount" + + # Test 2 - Search with _ character + result = TagService.get_tags("app", tenant.id, keyword="test_data") + assert len(result) == 1 + assert result[0].name == "test_data_tag" + + # Test 3 - Search with \ character + result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + assert len(result) == 1 + assert result[0].name == "path\\to\\tag" + + # Test 4 - Search with % should NOT match 100% (verifies escaping works) + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. 
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 7b95944bbe..040fb826e1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService from services.workflow_app_service import WorkflowAppService @@ -86,6 +88,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -147,6 +152,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -308,6 +316,156 @@ class TestWorkflowAppService: assert result_no_match["total"] == 0 assert len(result_no_match["data"]) == 0 + def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + from extensions.ext_database import db + + service = WorkflowAppService() + + # Test 1: Search with % character + workflow_run_1 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}), + outputs=json.dumps({"result": "50% discount applied", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_1) + db.session.flush() + + workflow_app_log_1 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_1.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_1.id = str(uuid.uuid4()) + workflow_app_log_1.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_1) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should find the workflow_run_1 entry + assert result["total"] >= 1 + 
assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"]) + + # Test 2: Search with _ character + workflow_run_2 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}), + outputs=json.dumps({"result": "test_data_value found", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_2) + db.session.flush() + + workflow_app_log_2 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_2.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_2.id = str(uuid.uuid4()) + workflow_app_log_2.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_2) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 + ) + # Should find the workflow_run_2 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"]) + + # Test 3: Search with % should NOT match 100% (verifies escaping works correctly) + workflow_run_4 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}), + outputs=json.dumps({"result": "100% different result", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_4) + db.session.flush() + + workflow_app_log_4 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_4.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_4.id = str(uuid.uuid4()) + workflow_app_log_4.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_4) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4) + # This verifies that escaping works correctly - 50% should not match 100% + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + # Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different) + found_run_ids = [log.workflow_run_id for log in result["data"]] + assert workflow_run_1.id in found_run_ids + assert workflow_run_4.id not in found_run_ids + def test_get_paginate_workflow_app_logs_with_status_filter( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index e29b98037f..b9977b1fb6 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -165,7 +165,7 @@ class TestRagPipelineRunTasks: "files": [], "user_id": account.id, "stream": False, - "invoke_from": "published", + "invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value, "workflow_execution_id": str(uuid.uuid4()), "pipeline_config": { "app_id": str(uuid.uuid4()), @@ -249,7 +249,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -294,7 +294,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -743,7 +743,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 209b6bf59b..6fce7849f9 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): os.environ.clear() # Set minimal required env vars + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", 
"postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): """Test that DB_EXTRAS options are properly merged with default timezone setting""" # Set environment variables + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password( # Set up basic required environment variables (following existing pattern) monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py new file mode 100644 index 0000000000..40eb59a8f4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import builtins +import sys +from datetime import datetime +from importlib import util +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import pytest +from flask.views import MethodView + +# kombu references MethodView as a global when importing celery/kombu pools. +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_app_module(): + module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + + root = Path(__file__).resolve().parents[5] + module_path = root / "controllers" / "console" / "app" / "app.py" + + class _StubNamespace: + def __init__(self): + self.models: dict[str, Any] = {} + self.payload = None + + def schema_model(self, name, schema): + self.models[name] = schema + + def _decorator(self, obj): + return obj + + def doc(self, *args, **kwargs): + return self._decorator + + def expect(self, *args, **kwargs): + return self._decorator + + def response(self, *args, **kwargs): + return self._decorator + + def route(self, *args, **kwargs): + def decorator(obj): + return obj + + return decorator + + stub_namespace = _StubNamespace() + + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") + stubbed_modules: list[tuple[str, ModuleType | None]] = [] + + console_module = ModuleType("controllers.console") + console_module.__path__ = [str(root / "controllers" / "console")] + console_module.console_ns = stub_namespace + console_module.api = None + console_module.bp = None + sys.modules["controllers.console"] = console_module + + app_package = ModuleType("controllers.console.app") + app_package.__path__ = [str(root / "controllers" / "console" / "app")] + sys.modules["controllers.console.app"] = app_package + console_module.app = app_package + + def _stub_module(name: str, attrs: dict[str, Any]): + original = sys.modules.get(name) + module = ModuleType(name) + for key, value in attrs.items(): + setattr(module, key, value) + sys.modules[name] = module + 
stubbed_modules.append((name, original)) + + class _OpsTraceManager: + @staticmethod + def get_app_tracing_config(app_id: str) -> dict[str, Any]: + return {} + + @staticmethod + def update_app_tracing_config(app_id: str, **kwargs) -> None: + return None + + _stub_module( + "core.ops.ops_trace_manager", + { + "OpsTraceManager": _OpsTraceManager, + "TraceQueueManager": object, + "TraceTask": object, + }, + ) + + spec = util.spec_from_file_location(module_name, module_path) + module = util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + assert spec.loader is not None + spec.loader.exec_module(module) + finally: + for name, original in reversed(stubbed_modules): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module + + +_app_module = _load_app_module() +AppDetailWithSite = _app_module.AppDetailWithSite +AppPagination = _app_module.AppPagination +AppPartial = _app_module.AppPartial + + +@pytest.fixture(autouse=True) +def patch_signed_url(monkeypatch): + """Ensure icon URL generation uses a deterministic helper for tests.""" + + def _fake_signed_url(key: str | None) -> str | None: + if not key: + return None + return f"signed:{key}" + + monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + + +def _ts(hour: int = 12) -> datetime: + return datetime(2024, 1, 1, hour, 0, 0) + + +def _dummy_model_config(): + return SimpleNamespace( + model_dict={"provider": "openai", "name": "gpt-4o"}, + pre_prompt="hello", + created_by="config-author", + created_at=_ts(9), + updated_by="config-editor", + updated_at=_ts(10), + ) + + +def _dummy_workflow(): + return SimpleNamespace( + id="wf-1", + created_by="workflow-author", + created_at=_ts(8), + updated_by="workflow-editor", + updated_at=_ts(9), + ) + + +def test_app_partial_serialization_uses_aliases(): + created_at = _ts() + app_obj = SimpleNamespace( + id="app-1", + name="My App", + desc_or_prompt="Prompt snippet", + mode_compatible_with_agent="chat", + icon_type="image", + icon="icon-key", + icon_background="#fff", + app_model_config=_dummy_model_config(), + workflow=_dummy_workflow(), + created_by="creator", + created_at=created_at, + updated_by="editor", + updated_at=created_at, + tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")], + access_mode="private", + create_user_name="Creator", + author_name="Author", + has_draft_trigger=True, + ) + + serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["description"] == "Prompt snippet" + assert serialized["mode"] == "chat" + assert serialized["icon_url"] == "signed:icon-key" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["updated_at"] == int(created_at.timestamp()) + assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"} + assert serialized["workflow"]["id"] == "wf-1" + assert serialized["tags"][0]["name"] == "Utilities" + + +def test_app_detail_with_site_includes_nested_serialization(): + timestamp = _ts(14) + site = SimpleNamespace( + code="site-code", + title="Public Site", + icon_type="image", + icon="site-icon", + created_at=timestamp, + 
updated_at=timestamp, + ) + app_obj = SimpleNamespace( + id="app-2", + name="Detailed App", + description="Desc", + mode_compatible_with_agent="advanced-chat", + icon_type="image", + icon="detail-icon", + icon_background="#123456", + enable_site=True, + enable_api=True, + app_model_config={ + "opening_statement": "hi", + "model": {"provider": "openai", "name": "gpt-4o"}, + "retriever_resource": {"enabled": True}, + }, + workflow=_dummy_workflow(), + tracing={"enabled": True}, + use_icon_as_answer_icon=True, + created_by="creator", + created_at=timestamp, + updated_by="editor", + updated_at=timestamp, + access_mode="public", + tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")], + api_base_url="https://api.example.com/v1", + max_active_requests=5, + deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}], + site=site, + ) + + serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["icon_url"] == "signed:detail-icon" + assert serialized["model_config"]["retriever_resource"] == {"enabled": True} + assert serialized["deleted_tools"][0]["tool_name"] == "search" + assert serialized["site"]["icon_url"] == "signed:site-icon" + assert serialized["site"]["created_at"] == int(timestamp.timestamp()) + + +def test_app_pagination_aliases_per_page_and_has_next(): + item_one = SimpleNamespace( + id="app-10", + name="Paginated One", + desc_or_prompt="Summary", + mode_compatible_with_agent="chat", + icon_type="image", + icon="first-icon", + created_at=_ts(15), + updated_at=_ts(15), + ) + item_two = SimpleNamespace( + id="app-11", + name="Paginated Two", + desc_or_prompt="Summary", + mode_compatible_with_agent="agent-chat", + icon_type="emoji", + icon="🙂", + created_at=_ts(16), + updated_at=_ts(16), + ) + pagination = SimpleNamespace( + page=2, + per_page=10, + total=50, + has_next=True, + items=[item_one, item_two], + ) + + serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json") + + assert serialized["page"] == 2 + assert serialized["limit"] == 10 + assert serialized["has_more"] is True + assert len(serialized["data"]) == 2 + assert serialized["data"][0]["icon_url"] == "signed:first-icon" + assert serialized["data"][1]["icon_url"] is None diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py index 2630fbcfd0..370bf63fdb 100644 --- a/api/tests/unit_tests/controllers/console/test_files_security.py +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -1,7 +1,9 @@ +import builtins import io from unittest.mock import patch import pytest +from flask.views import MethodView from werkzeug.exceptions import Forbidden from controllers.common.errors import ( @@ -14,6 +16,9 @@ from controllers.common.errors import ( from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + class TestFileUploadSecurity: """Test file upload security logic without complex framework setup""" @@ -128,7 +133,7 @@ class TestFileUploadSecurity: # Test passes if no exception is raised # Test 4: Service error handling - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def 
test_should_handle_file_too_large_error(self, mock_upload): """Test that service FileTooLargeError is properly converted""" mock_upload.side_effect = ServiceFileTooLargeError("File too large") @@ -140,7 +145,7 @@ class TestFileUploadSecurity: with pytest.raises(FileTooLargeError): raise FileTooLargeError(e.description) - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_unsupported_file_type_error(self, mock_upload): """Test that service UnsupportedFileTypeError is properly converted""" mock_upload.side_effect = ServiceUnsupportedFileTypeError() diff --git a/api/tests/unit_tests/core/agent/__init__.py b/api/tests/unit_tests/core/agent/__init__.py new file mode 100644 index 0000000000..a9ccd45f4b --- /dev/null +++ b/api/tests/unit_tests/core/agent/__init__.py @@ -0,0 +1,3 @@ +""" +Mark agent test modules as a package to avoid import name collisions. +""" diff --git a/api/tests/unit_tests/core/agent/patterns/__init__.py b/api/tests/unit_tests/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py new file mode 100644 index 0000000000..b0e0d44940 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -0,0 +1,324 @@ +"""Tests for AgentPattern base class.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.agent.patterns.base import AgentPattern +from core.model_runtime.entities.llm_entities import LLMUsage + + +class ConcreteAgentPattern(AgentPattern): + """Concrete implementation of AgentPattern for testing.""" + + def run(self, prompt_messages, model_parameters, stop=[], stream=True): + """Minimal implementation for testing.""" + yield from [] + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def agent_pattern(mock_model_instance, mock_context): + """Create a concrete agent pattern for testing.""" + return ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=10, + ) + + +class TestAccumulateUsage: + """Tests for _accumulate_usage method.""" + + def test_accumulate_usage_to_empty_dict(self, agent_pattern): + """Test accumulating usage to an empty dict creates a copy.""" + total_usage: dict = {"usage": None} + delta_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"] is not None + assert total_usage["usage"].total_tokens == 150 + assert total_usage["usage"].prompt_tokens == 100 + assert 
total_usage["usage"].completion_tokens == 50 + # Verify it's a copy, not a reference + assert total_usage["usage"] is not delta_usage + + def test_accumulate_usage_adds_to_existing(self, agent_pattern): + """Test accumulating usage adds to existing values.""" + initial_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + total_usage: dict = {"usage": initial_usage} + + delta_usage = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"].total_tokens == 450 # 150 + 300 + assert total_usage["usage"].prompt_tokens == 300 # 100 + 200 + assert total_usage["usage"].completion_tokens == 150 # 50 + 100 + + def test_accumulate_usage_multiple_rounds(self, agent_pattern): + """Test accumulating usage across multiple rounds.""" + total_usage: dict = {"usage": None} + + # Round 1: 100 tokens + round1_usage = LLMUsage( + prompt_tokens=70, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.07"), + completion_tokens=30, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.06"), + total_tokens=100, + total_price=Decimal("0.13"), + currency="USD", + latency=0.3, + ) + agent_pattern._accumulate_usage(total_usage, round1_usage) + assert total_usage["usage"].total_tokens == 100 + + # Round 2: 150 tokens + round2_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.4, + ) + agent_pattern._accumulate_usage(total_usage, round2_usage) + assert total_usage["usage"].total_tokens == 250 # 100 + 150 + + # Round 3: 200 tokens + round3_usage = LLMUsage( + prompt_tokens=130, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.13"), + completion_tokens=70, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.14"), + total_tokens=200, + total_price=Decimal("0.27"), + currency="USD", + latency=0.5, + ) + agent_pattern._accumulate_usage(total_usage, round3_usage) + assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200 + + +class TestCreateLog: + """Tests for _create_log method.""" + + def test_create_log_with_label_and_status(self, agent_pattern): + """Test creating a log with label and status.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == 
AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.parent_id is None + + def test_create_log_with_parent_id(self, agent_pattern): + """Test creating a log with parent_id.""" + parent_log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + child_log = agent_pattern._create_log( + label="CALL tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={}, + parent_id=parent_log.id, + ) + + assert child_log.parent_id == parent_log.id + assert child_log.log_type == AgentLog.LogType.TOOL_CALL + + +class TestFinishLog: + """Tests for _finish_log method.""" + + def test_finish_log_updates_status(self, agent_pattern): + """Test that finish_log updates status to SUCCESS.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + finished_log = agent_pattern._finish_log(log, data={"result": "done"}) + + assert finished_log.status == AgentLog.LogStatus.SUCCESS + assert finished_log.data == {"result": "done"} + + def test_finish_log_adds_usage_metadata(self, agent_pattern): + """Test that finish_log adds usage to metadata.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + finished_log = agent_pattern._finish_log(log, usage=usage) + + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150 + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2") + assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD" + assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage + + +class TestFindToolByName: + """Tests for _find_tool_by_name method.""" + + def test_find_existing_tool(self, mock_model_instance, mock_context): + """Test finding an existing tool by name.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("test_tool") + assert found_tool == mock_tool + + def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context): + """Test that finding a nonexistent tool returns None.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("nonexistent_tool") + assert found_tool is None + + +class TestMaxIterationsCapping: + """Tests for max_iterations capping.""" + + def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context): + """Test that max_iterations is capped at 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=150, + ) + + assert pattern.max_iterations == 99 + + def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context): + """Test that 
max_iterations is not capped when under 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=50, + ) + + assert pattern.max_iterations == 50 diff --git a/api/tests/unit_tests/core/agent/patterns/test_function_call.py b/api/tests/unit_tests/core/agent/patterns/test_function_call.py new file mode 100644 index 0000000000..6b3600dbbf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_function_call.py @@ -0,0 +1,332 @@ +"""Tests for FunctionCallStrategy.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import ( + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.to_prompt_message_tool.return_value = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={ + "type": "object", + "properties": {"param1": {"type": "string", "description": "A parameter"}}, + "required": ["param1"], + }, + ) + return tool + + +class TestFunctionCallStrategyInit: + """Tests for FunctionCallStrategy initialization.""" + + def test_initialization(self, mock_model_instance, mock_context, mock_tool): + """Test basic initialization.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=10, + ) + + assert strategy.model_instance == mock_model_instance + assert strategy.context == mock_context + assert strategy.max_iterations == 10 + assert len(strategy.tools) == 1 + + def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool): + """Test initialization with tool_invoke_hook.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + mock_hook = MagicMock() + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook + + +class TestConvertToolsToPromptFormat: + """Tests for _convert_tools_to_prompt_format method.""" + + def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool): + """Test that _convert_tools_to_prompt_format returns PromptMessageTool list.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert len(tools) == 1 + assert isinstance(tools[0], PromptMessageTool) + assert tools[0].name == "test_tool" + + def test_convert_tools_empty_when_no_tools(self, mock_model_instance, 
mock_context): + """Test that _convert_tools_to_prompt_format returns empty list when no tools.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert tools == [] + + +class TestAgentLogGeneration: + """Tests for AgentLog generation during run.""" + + def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that round logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=1, + ) + + # Create a round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"inputs": {"query": "test"}}, + ) + + assert round_log.label == "ROUND 1" + assert round_log.log_type == AgentLog.LogType.ROUND + assert round_log.status == AgentLog.LogStatus.START + assert "inputs" in round_log.data + + def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that tool call logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + # Create a parent round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + # Create a tool call log + tool_log = strategy._create_log( + label="CALL test_tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}}, + parent_id=round_log.id, + ) + + assert tool_log.label == "CALL test_tool" + assert tool_log.log_type == AgentLog.LogType.TOOL_CALL + assert tool_log.parent_id == round_log.id + assert tool_log.data["tool_name"] == "test_tool" + + +class TestToolInvocation: + """Tests for tool invocation.""" + + def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool): + """Test that tool invocation uses hook when provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_hook = MagicMock() + mock_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={"tool_provider_type": "test", "tool_provider": "test_id"}, + ) + mock_hook.return_value = ("Tool result", ["file-1"], mock_meta) + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool") + + mock_hook.assert_called_once() + assert result == "Tool result" + assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list + assert meta == mock_meta + + def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool): + """Test that tool_invoke_hook is None when not provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=None, + ) + + # Verify 
that tool_invoke_hook is None + assert strategy.tool_invoke_hook is None + + +class TestUsageTracking: + """Tests for usage tracking across rounds.""" + + def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context): + """Test that round usage is tracked separately from total.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + # Simulate two rounds of usage + total_usage: dict = {"usage": None} + round1_usage: dict = {"usage": None} + round2_usage: dict = {"usage": None} + + # Round 1 + usage1 = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round1_usage, usage1) + strategy._accumulate_usage(total_usage, usage1) + + # Round 2 + usage2 = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round2_usage, usage2) + strategy._accumulate_usage(total_usage, usage2) + + # Verify round usage is separate + assert round1_usage["usage"].total_tokens == 150 + assert round2_usage["usage"].total_tokens == 300 + # Verify total is accumulated + assert total_usage["usage"].total_tokens == 450 + + +class TestPromptMessageHandling: + """Tests for prompt message handling.""" + + def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool): + """Test that messages include system and user prompts.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="You are a helpful assistant."), + UserPromptMessage(content="Hello"), + ] + + # Just verify the messages can be processed + assert len(messages) == 2 + assert isinstance(messages[0], SystemPromptMessage) + assert isinstance(messages[1], UserPromptMessage) + + def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool): + """Test that assistant messages can contain tool calls.""" + from core.model_runtime.entities.message_entities import AssistantPromptMessage + + tool_call = AssistantPromptMessage.ToolCall( + id="call_123", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="test_tool", + arguments='{"param1": "value1"}', + ), + ) + + assistant_message = AssistantPromptMessage( + content="I'll help you with that.", + tool_calls=[tool_call], + ) + + assert len(assistant_message.tool_calls) == 1 + assert assistant_message.tool_calls[0].function.name == "test_tool" diff --git a/api/tests/unit_tests/core/agent/patterns/test_react.py b/api/tests/unit_tests/core/agent/patterns/test_react.py new file mode 100644 index 0000000000..a942ba6100 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_react.py @@ -0,0 +1,224 @@ +"""Tests 
for ReActStrategy.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import ExecutionContext +from core.agent.patterns.react import ReActStrategy +from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + from core.model_runtime.entities.message_entities import PromptMessageTool + + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.entity.identity.provider = "test_provider" + + # Use real PromptMessageTool for proper serialization + prompt_tool = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + tool.to_prompt_message_tool.return_value = prompt_tool + + return tool + + +class TestReActStrategyInit: + """Tests for ReActStrategy initialization.""" + + def test_init_with_instruction(self, mock_model_instance, mock_context): + """Test that instruction is stored correctly.""" + instruction = "You are a helpful assistant." + + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + assert strategy.instruction == instruction + + def test_init_with_empty_instruction(self, mock_model_instance, mock_context): + """Test that empty instruction is handled correctly.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + assert strategy.instruction == "" + + +class TestBuildPromptWithReactFormat: + """Tests for _build_prompt_with_react_format method.""" + + def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tools}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + UserPromptMessage(content="Hello"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # The tools placeholder should be replaced with JSON + assert "{{tools}}" not in result[0].content + assert "test_tool" in result[0].content + + def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tool_names}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "Valid actions: {{tool_names}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + assert "{{tool_names}}" not in result[0].content + assert '"test_tool"' in result[0].content + + def test_replace_instruction_placeholder(self, mock_model_instance, mock_context): + """Test that {{instruction}} placeholder is replaced.""" + instruction = "You are a helpful coding assistant." 
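+        # Note: the instruction is handed to the constructor and is also
+        # passed explicitly to _build_prompt_with_react_format in the call below.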
+ strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + system_content = "{{instruction}}\n\nYou have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True, instruction) + + assert "{{instruction}}" not in result[0].content + assert instruction in result[0].content + + def test_no_tools_available_message(self, mock_model_instance, mock_context): + """Test that 'No tools available' is shown when include_tools is False.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], False) + + assert "No tools available" in result[0].content + + def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context): + """Test that agent scratchpad is appended as AssistantPromptMessage.""" + from core.agent.entities import AgentScratchpadUnit + from core.model_runtime.entities import AssistantPromptMessage + + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + scratchpad = [ + AgentScratchpadUnit( + thought="I need to search for information", + action_str='{"action": "search", "action_input": "query"}', + observation="Search results here", + ) + ] + + result = strategy._build_prompt_with_react_format(messages, scratchpad, True) + + # The last message should be an AssistantPromptMessage with scratchpad content + assert len(result) == 3 + assert isinstance(result[-1], AssistantPromptMessage) + assert "I need to search for information" in result[-1].content + assert "Search results here" in result[-1].content + + def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context): + """Test that empty scratchpad doesn't add extra message.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # Should only have the original 2 messages + assert len(result) == 2 + + def test_original_messages_not_modified(self, mock_model_instance, mock_context): + """Test that original messages list is not modified.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + original_content = "Original system prompt {{tools}}" + messages = [ + SystemPromptMessage(content=original_content), + ] + + strategy._build_prompt_with_react_format(messages, [], True) + + # Original message should not be modified + assert messages[0].content == original_content diff --git a/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py new file mode 100644 index 0000000000..07b9df2acf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py @@ -0,0 +1,203 @@ +"""Tests for StrategyFactory.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentEntity, ExecutionContext +from 
core.agent.patterns.function_call import FunctionCallStrategy +from core.agent.patterns.react import ReActStrategy +from core.agent.patterns.strategy_factory import StrategyFactory +from core.model_runtime.entities.model_entities import ModelFeature + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +class TestStrategyFactory: + """Tests for StrategyFactory.create_strategy method.""" + + def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports TOOL_CALL.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL.""" + model_features = [ModelFeature.MULTI_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL.""" + model_features = [ModelFeature.STREAM_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model doesn't support tool calling.""" + model_features = [ModelFeature.VISION] # Only vision, no tool calling + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model has no features.""" + model_features: list[ModelFeature] = [] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context): + """Test explicit FUNCTION_CALLING strategy selection with model support.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + 
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_explicit_function_calling_strategy_without_support_falls_back_to_react( + self, mock_model_instance, mock_context + ): + """Test that explicit FUNCTION_CALLING falls back to ReAct when not supported.""" + model_features: list[ModelFeature] = [] # No tool calling support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + # Should fall back to ReAct since FC is not supported + assert isinstance(strategy, ReActStrategy) + + def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context): + """Test explicit CHAIN_OF_THOUGHT strategy selection.""" + model_features = [ModelFeature.TOOL_CALL] # Even with tool call support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + ) + + assert isinstance(strategy, ReActStrategy) + + def test_react_strategy_with_instruction(self, mock_model_instance, mock_context): + """Test that ReActStrategy receives instruction parameter.""" + model_features: list[ModelFeature] = [] + instruction = "You are a helpful assistant." + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + instruction=instruction, + ) + + assert isinstance(strategy, ReActStrategy) + assert strategy.instruction == instruction + + def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that max_iterations is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + max_iterations = 5 + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + max_iterations=max_iterations, + ) + + assert strategy.max_iterations == max_iterations + + def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that tool_invoke_hook is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + mock_hook = MagicMock() + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py new file mode 100644 index 0000000000..d9301ccfe0 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -0,0 +1,388 @@ +"""Tests for AgentAppRunner.""" + +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult +from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.llm_entities import LLMUsage + + +class TestOrganizePromptMessages: + """Tests for _organize_prompt_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + # We'll patch the class to avoid 
complex initialization + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + + # Set up required attributes + runner.config = MagicMock(spec=AgentEntity) + runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING + runner.config.prompt = None + + runner.app_config = MagicMock() + runner.app_config.prompt_template = MagicMock() + runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant." + + runner.history_prompt_messages = [] + runner.query = "Hello" + runner._current_thoughts = [] + runner.files = [] + runner.model_config = MagicMock() + runner.memory = None + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.file_upload_config = None + + return runner + + def test_function_calling_uses_simple_prompt(self, mock_runner): + """Test that function calling strategy uses simple_prompt_template.""" + mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="You are a helpful assistant.") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with simple_prompt_template + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "You are a helpful assistant." 
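+    # The two chain-of-thought tests below reuse the same nested-patch setup;
+    # only the template expected to reach _init_system_message changes.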
+ + def test_chain_of_thought_uses_agent_prompt(self, mock_runner): + """Test that chain of thought strategy uses agent prompt template.""" + mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + mock_runner.config.prompt = AgentPromptEntity( + first_prompt="ReAct prompt template with {{tools}}", + next_iteration="Continue...", + ) + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="ReAct prompt template with {{tools}}") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with agent prompt + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "ReAct prompt template with {{tools}}" + + def test_chain_of_thought_without_prompt_falls_back(self, mock_runner): + """Test that chain of thought without prompt falls back to simple_prompt_template.""" + mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + mock_runner.config.prompt = None + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="You are a helpful assistant.") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with simple_prompt_template + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "You are a helpful assistant." 
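+# The fixtures below build AgentAppRunner via __new__ with
+# BaseAgentRunner.__init__ patched to a no-op, so the helper methods can be
+# tested without a database, queue manager, or model configuration.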
+ + +class TestInitSystemMessage: + """Tests for _init_system_message method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_empty_messages_with_template(self, mock_runner): + """Test that system message is created when messages are empty.""" + result = mock_runner._init_system_message("System template", []) + + assert len(result) == 1 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + def test_empty_messages_without_template(self, mock_runner): + """Test that empty list is returned when no template and no messages.""" + result = mock_runner._init_system_message("", []) + + assert result == [] + + def test_existing_system_message_not_duplicated(self, mock_runner): + """Test that system message is not duplicated if already present.""" + existing_messages = [ + SystemPromptMessage(content="Existing system"), + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("New template", existing_messages) + + # Should not insert new system message + assert len(result) == 2 + assert result[0].content == "Existing system" + + def test_system_message_inserted_when_missing(self, mock_runner): + """Test that system message is inserted when first message is not system.""" + existing_messages = [ + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("System template", existing_messages) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + +class TestClearUserPromptImageMessages: + """Tests for _clear_user_prompt_image_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_text_content_unchanged(self, mock_runner): + """Test that text content is unchanged.""" + messages = [ + UserPromptMessage(content="Plain text message"), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + assert len(result) == 1 + assert result[0].content == "Plain text message" + + def test_original_messages_not_modified(self, mock_runner): + """Test that original messages are not modified (deep copy).""" + from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Text part"), + ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ), + ] + ), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + # Original should still have list content + assert isinstance(messages[0].content, list) + # Result should have string content + assert isinstance(result[0].content, str) + + +class TestToolInvokeHook: + """Tests for _create_tool_invoke_hook method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + 
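+            # BaseAgentRunner.__init__ is replaced with a no-op here so the
+            # runner can be instantiated without its usual dependencies.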
from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + + runner.user_id = "test-user" + runner.tenant_id = "test-tenant" + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.trace_manager = None + runner.application_generate_entity.invoke_from = "api" + runner.application_generate_entity.app_config = MagicMock() + runner.application_generate_entity.app_config.app_id = "test-app" + runner.agent_callback = MagicMock() + runner.conversation = MagicMock() + runner.conversation.id = "test-conversation" + runner.queue_manager = MagicMock() + runner._current_message_file_ids = [] + + return runner + + def test_hook_calls_agent_invoke(self, mock_runner): + """Test that the hook calls ToolEngine.agent_invoke.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={ + "tool_provider_type": "test_provider", + "tool_provider": "test_id", + }, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool") + + # Verify ToolEngine.agent_invoke was called + mock_engine.agent_invoke.assert_called_once() + + # Verify return values + assert result_content == "Tool result" + assert result_files == ["file-1", "file-2"] + assert result_meta == mock_tool_meta + + def test_hook_publishes_file_events(self, mock_runner): + """Test that the hook publishes QueueMessageFileEvent for files.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={}, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + hook(mock_tool, {}, "test_tool") + + # Verify file events were published + assert mock_runner.queue_manager.publish.call_count == 2 + assert mock_runner._current_message_file_ids == ["file-1", "file-2"] + + +class TestAgentLogProcessing: + """Tests for AgentLog processing in run method.""" + + def test_agent_log_status_enum(self): + """Test AgentLog status enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_agent_log_metadata_enum(self): + """Test AgentLog metadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + def test_agent_result_structure(self): + """Test AgentResult structure.""" + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + 
completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + result = AgentResult( + text="Final answer", + files=[], + usage=usage, + finish_reason="stop", + ) + + assert result.text == "Final answer" + assert result.files == [] + assert result.usage == usage + assert result.finish_reason == "stop" + + +class TestOrganizeUserQuery: + """Tests for _organize_user_query method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + runner.files = [] + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.file_upload_config = None + return runner + + def test_simple_query_without_files(self, mock_runner): + """Test organizing a simple query without files.""" + result = mock_runner._organize_user_query("Hello world", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == "Hello world" + + def test_query_with_files(self, mock_runner): + """Test organizing a query with files.""" + from core.file.models import File + + mock_file = MagicMock(spec=File) + mock_runner.files = [mock_file] + + with patch("core.agent.agent_app_runner.file_manager") as mock_fm: + from core.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ) + + result = mock_runner._organize_user_query("Describe this image", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert isinstance(result[0].content, list) + assert len(result[0].content) == 2 # Image + Text diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py new file mode 100644 index 0000000000..5136f48aab --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_entities.py @@ -0,0 +1,191 @@ +"""Tests for agent entities.""" + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext + + +class TestExecutionContext: + """Tests for ExecutionContext entity.""" + + def test_create_with_all_fields(self): + """Test creating ExecutionContext with all fields.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + assert context.user_id == "user-123" + assert context.app_id == "app-456" + assert context.conversation_id == "conv-789" + assert context.message_id == "msg-012" + assert context.tenant_id == "tenant-345" + + def test_create_minimal(self): + """Test creating minimal ExecutionContext.""" + context = ExecutionContext.create_minimal(user_id="user-123") + + assert context.user_id == "user-123" + assert context.app_id is None + assert context.conversation_id is None + assert context.message_id is None + assert context.tenant_id is None + + def test_to_dict(self): + """Test converting ExecutionContext to dictionary.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + result = context.to_dict() + + assert 
result == { + "user_id": "user-123", + "app_id": "app-456", + "conversation_id": "conv-789", + "message_id": "msg-012", + "tenant_id": "tenant-345", + } + + def test_with_updates(self): + """Test creating new context with updates.""" + original = ExecutionContext( + user_id="user-123", + app_id="app-456", + ) + + updated = original.with_updates(message_id="msg-789") + + # Original should be unchanged + assert original.message_id is None + # Updated should have new value + assert updated.message_id == "msg-789" + assert updated.user_id == "user-123" + assert updated.app_id == "app-456" + + +class TestAgentLog: + """Tests for AgentLog entity.""" + + def test_create_log_with_required_fields(self): + """Test creating AgentLog with required fields.""" + log = AgentLog( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.id is not None # Auto-generated + assert log.parent_id is None + assert log.error is None + + def test_log_type_enum(self): + """Test LogType enum values.""" + assert AgentLog.LogType.ROUND == "round" + assert AgentLog.LogType.THOUGHT == "thought" + assert AgentLog.LogType.TOOL_CALL == "tool_call" + + def test_log_status_enum(self): + """Test LogStatus enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_log_metadata_enum(self): + """Test LogMetadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + +class TestAgentScratchpadUnit: + """Tests for AgentScratchpadUnit entity.""" + + def test_is_final_with_final_answer_action(self): + """Test is_final returns True for Final Answer action.""" + unit = AgentScratchpadUnit( + thought="I know the answer", + action=AgentScratchpadUnit.Action( + action_name="Final Answer", + action_input="The answer is 42", + ), + ) + + assert unit.is_final() is True + + def test_is_final_with_tool_action(self): + """Test is_final returns False for tool action.""" + unit = AgentScratchpadUnit( + thought="I need to search", + action=AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ), + ) + + assert unit.is_final() is False + + def test_is_final_with_no_action(self): + """Test is_final returns True when no action.""" + unit = AgentScratchpadUnit( + thought="Just thinking", + ) + + assert unit.is_final() is True + + def test_action_to_dict(self): + """Test Action.to_dict method.""" + action = AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ) + + result = action.to_dict() + + assert result == { + "action": "search", + "action_input": {"query": "test"}, + } + + +class TestAgentEntity: + """Tests for AgentEntity.""" + + def test_strategy_enum(self): + """Test Strategy enum values.""" + assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought" + assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling" + + def test_create_with_prompt(self): + """Test creating AgentEntity with prompt.""" + prompt 
= AgentPromptEntity( + first_prompt="You are a helpful assistant.", + next_iteration="Continue thinking...", + ) + + entity = AgentEntity( + provider="openai", + model="gpt-4", + strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + prompt=prompt, + max_iteration=5, + ) + + assert entity.provider == "openai" + assert entity.model == "gpt-4" + assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT + assert entity.prompt == prompt + assert entity.max_iteration == 5 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py new file mode 100644 index 0000000000..205b157542 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py @@ -0,0 +1,390 @@ +""" +Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method, +specifically testing the ANSWER node message_replace logic. +""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity +from core.app.entities.queue_entities import QueueNodeSucceededEvent +from core.workflow.enums import NodeType +from models import EndUser +from models.model import AppMode + + +class TestAnswerNodeMessageReplace: + """Test cases for ANSWER node message_replace event logic.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=AdvancedChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + entity.workflow_run_id = "test-workflow-run-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.ADVANCED_CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + entity.query = "test query" + entity.files = [] + entity.extras = {} + entity.trace_manager = None + entity.inputs = {} + entity.invoke_from = "debugger" + return entity + + @pytest.fixture + def mock_workflow(self): + """Create a mock workflow.""" + workflow = Mock() + workflow.id = "test-workflow-id" + workflow.features_dict = {} + return workflow + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock() + manager.listen.return_value = [] + manager.graph_runtime_state = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "advanced_chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.query = "test query" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_user(self): + """Create a mock end user.""" + user = MagicMock(spec=EndUser) + user.id = "test-user-id" + user.session_id = "test-session-id" + return user + + @pytest.fixture + def mock_draft_var_saver_factory(self): + """Create a mock draft variable saver factory.""" + return Mock() + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_workflow, + mock_queue_manager, + mock_conversation, + 
mock_message, + mock_user, + mock_draft_var_saver_factory, + ): + """Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies.""" + from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline + + with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"): + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + workflow=mock_workflow, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + user=mock_user, + stream=True, + dialogue_count=1, + draft_var_saver_factory=mock_draft_var_saver_factory, + ) + # Initialize workflow run id to avoid validation errors + pipeline._workflow_run_id = "test-workflow-run-id" + # Mock the message cycle manager methods we need to track + pipeline._message_cycle_manager.message_replace_to_stream_response = Mock() + return pipeline + + def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity): + """ + Test that when an ANSWER node's final output differs from accumulated answer, + a message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "initial answer" + + # Create ANSWER node succeeded event with different final output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "updated final answer"}, + ) + + # Mock the workflow response converter to avoid extra processing + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + responses = list(pipeline._handle_node_succeeded_event(event)) + + # Assert + assert pipeline._task_state.answer == "updated final answer" + # Verify message_replace was called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="updated final answer", reason="variable_update" + ) + + def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's final output is the same as accumulated answer, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "same answer" + + # Create ANSWER node succeeded event with same output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "same answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "same answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's output is None or missing 'answer' key, + no message_replace event is sent. 
+ """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with None output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": None}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node has empty outputs dict, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with empty outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_no_answer_key_in_outputs(self, pipeline): + """ + Test that when an ANSWER node's outputs don't contain 'answer' key, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event without 'answer' key in outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"other_key": "some value"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_non_answer_node_does_not_send_message_replace(self, pipeline): + """ + Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events. 
+ """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Test with LLM node + llm_event = QueueNodeSucceededEvent( + node_execution_id="test-llm-execution-id", + node_id="test-llm-node", + node_type=NodeType.LLM, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(llm_event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_end_node_does_not_send_message_replace(self, pipeline): + """ + Test that END nodes don't trigger message_replace events even with 'answer' output. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create END node succeeded event with answer output + event = QueueNodeSucceededEvent( + node_execution_id="test-end-execution-id", + node_id="test-end-node", + node_type=NodeType.END, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_numeric_output_converts_to_string(self, pipeline): + """ + Test that when an ANSWER node's final output is numeric, + it gets converted to string properly. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "text answer" + + # Create ANSWER node succeeded event with numeric output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": 12345}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should be converted to string + assert pipeline._task_state.answer == "12345" + # Verify message_replace was called with string + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="12345", reason="variable_update" + ) + + def test_answer_node_files_are_recorded(self, pipeline): + """ + Test that ANSWER nodes properly record files from outputs. 
+ """ + # Arrange + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with files + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={ + "answer": "same answer", + "files": [ + {"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"} + ], + }, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"]) + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: files should be recorded + assert len(pipeline._recorded_files) == 1 + assert pipeline._recorded_files[0] == event.outputs["files"][0] diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 1c9f577a50..6b40bf462b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation: description="Explore calls should have truncation enabled", ), TestCase( - name="published_truncation_enabled", - invoke_from=InvokeFrom.PUBLISHED, + name="published_pipeline_truncation_enabled", + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, expected_truncation_enabled=True, - description="Published app calls should have truncation enabled", + description="Published pipeline calls should have truncation enabled", ), ], ids=lambda x: x.name, diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py new file mode 100644 index 0000000000..8779e8c586 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.workflow.graph_events import NodeRunStreamChunkEvent +from core.workflow.nodes import NodeType + + +class DummyQueueManager: + def __init__(self) -> None: + self.published = [] + + def publish(self, event, publish_from: PublishFrom) -> None: + self.published.append((event, publish_from)) + + +def test_skip_empty_final_chunk() -> None: + queue_manager = DummyQueueManager() + runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app") + + empty_final_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="", + is_final=True, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event) + assert queue_manager.published == [] + + normal_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="hi", + is_final=False, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=normal_event) + + assert len(queue_manager.published) == 1 + published_event, publish_from = queue_manager.published[0] + assert publish_from == 
PublishFrom.APPLICATION_MANAGER + assert published_event.text == "hi" diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py new file mode 100644 index 0000000000..b6e8cc9c8e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -0,0 +1,144 @@ +from collections.abc import Sequence +from datetime import datetime +from unittest.mock import Mock + +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.variables import StringVariable +from core.variables.segments import Segment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.protocols.command_channel import CommandChannel +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from core.workflow.system_variable import SystemVariable + + +class MockReadOnlyVariablePool: + def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None: + self._variables = variables or {} + + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + return self._variables.get((selector[0], selector[1])) + + def get_all_by_node(self, node_id: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == node_id} + + def get_by_prefix(self, prefix: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} + + +def _build_graph_runtime_state( + variable_pool: MockReadOnlyVariablePool, + conversation_id: str | None = None, +) -> ReadOnlyGraphRuntimeState: + graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + graph_runtime_state.variable_pool = variable_pool + graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() + return graph_runtime_state + + +def _build_node_run_succeeded_event( + *, + node_type: NodeType, + outputs: dict[str, object] | None = None, + process_data: dict[str, object] | None = None, +) -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="node-exec-id", + node_id="assigner", + node_type=node_type, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs or {}, + process_data=process_data or {}, + ), + ) + + +def test_persists_conversation_variables_from_assigner_output(): + conversation_id = "conv-123" + variable = StringVariable( + id="var-1", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + 
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) + updater.flush.assert_called_once() + + +def test_skips_when_outputs_missing(): + conversation_id = "conv-456" + variable = StringVariable( + id="var-2", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_assigner_nodes(): + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_conversation_variables(): + conversation_id = "conv-789" + non_conversation_variable = StringVariable( + id="var-3", + name="name", + value="updated", + selector=["environment", "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] + ) + + variable_pool = MockReadOnlyVariablePool() + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 5a5386ee57..1d885f6b2e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -1,4 +1,5 @@ import json +from collections.abc import Sequence from time import time from unittest.mock import Mock @@ -67,8 +68,10 @@ class MockReadOnlyVariablePool: def __init__(self, variables: dict[tuple[str, str], object] | None = None): self._variables = variables or {} - def get(self, node_id: str, variable_key: str) -> Segment | None: - value = self._variables.get((node_id, variable_key)) + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + value = self._variables.get((selector[0], selector[1])) if value is None: return None mock_segment = Mock(spec=Segment) diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3203aab8c3..f9e59a5f05 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,8 +1,12 @@ """Primarily used for testing merged cell scenarios""" +import os +import tempfile from types import SimpleNamespace from docx import Document +from docx.oxml import OxmlElement +from docx.oxml.ns import qn import 
core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url(): dify_config.FILES_URL = original_files_url if original_internal_files_url is not None: dify_config.INTERNAL_FILES_URL = original_internal_files_url + + +def test_extract_hyperlinks(monkeypatch): + # Mock db and storage to avoid issues during image extraction (even if no images are present) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph("Visit ") + + # Adding a hyperlink manually + r_id = "rId99" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + new_run = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + new_run.append(t) + hyperlink.append(new_run) + p._p.append(hyperlink) + + # Add relationship to the part + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify modern hyperlink extraction + assert "Visit[Dify](https://dify.ai)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_extract_legacy_hyperlinks(monkeypatch): + # Mock db and storage + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph() + + # Construct a legacy HYPERLINK field: + # 1. w:fldChar (begin) + # 2. w:instrText (HYPERLINK "http://example.com") + # 3. w:fldChar (separate) + # 4. w:r (visible text "Example") + # 5. 
w:fldChar (end) + + run1 = OxmlElement("w:r") + fldCharBegin = OxmlElement("w:fldChar") + fldCharBegin.set(qn("w:fldCharType"), "begin") + run1.append(fldCharBegin) + p._p.append(run1) + + run2 = OxmlElement("w:r") + instrText = OxmlElement("w:instrText") + instrText.text = ' HYPERLINK "http://example.com" ' + run2.append(instrText) + p._p.append(run2) + + run3 = OxmlElement("w:r") + fldCharSep = OxmlElement("w:fldChar") + fldCharSep.set(qn("w:fldCharType"), "separate") + run3.append(fldCharSep) + p._p.append(run3) + + run4 = OxmlElement("w:r") + t4 = OxmlElement("w:t") + t4.text = "Example" + run4.append(t4) + p._p.append(run4) + + run5 = OxmlElement("w:r") + fldCharEnd = OxmlElement("w:fldChar") + fldCharEnd.set(qn("w:fldCharType"), "end") + run5.append(fldCharEnd) + p._p.append(run5) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify legacy hyperlink extraction + assert "[Example](http://example.com)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..fe3ea576c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -3,6 +3,7 @@ from __future__ import annotations import queue +import threading from unittest import mock from core.workflow.entities.pause_reason import SchedulingPause @@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=execution_coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() assert event_queue.empty() @@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int: event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() @@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index b02f90588b..5ceb8dd7f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing the behavior of mock nodes during testing. 
""" +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -95,67 +97,67 @@ class MockConfigBuilder: def __init__(self) -> None: self._config = MockConfig() - def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder": + def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable auto-mocking.""" self._config.enable_auto_mock = enabled return self - def with_delays(self, enabled: bool = True) -> "MockConfigBuilder": + def with_delays(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable simulated execution delays.""" self._config.simulate_delays = enabled return self - def with_llm_response(self, response: str) -> "MockConfigBuilder": + def with_llm_response(self, response: str) -> MockConfigBuilder: """Set default LLM response.""" self._config.default_llm_response = response return self - def with_agent_response(self, response: str) -> "MockConfigBuilder": + def with_agent_response(self, response: str) -> MockConfigBuilder: """Set default agent response.""" self._config.default_agent_response = response return self - def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default tool response.""" self._config.default_tool_response = response return self - def with_retrieval_response(self, response: str) -> "MockConfigBuilder": + def with_retrieval_response(self, response: str) -> MockConfigBuilder: """Set default retrieval response.""" self._config.default_retrieval_response = response return self - def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default HTTP response.""" self._config.default_http_response = response return self - def with_template_transform_response(self, response: str) -> "MockConfigBuilder": + def with_template_transform_response(self, response: str) -> MockConfigBuilder: """Set default template transform response.""" self._config.default_template_transform_response = response return self - def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default code execution response.""" self._config.default_code_response = response return self - def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder": + def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder: """Set outputs for a specific node.""" self._config.set_node_outputs(node_id, outputs) return self - def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder": + def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder: """Set error for a specific node.""" self._config.set_node_error(node_id, error) return self - def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder": + def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder: """Add a node-specific configuration.""" self._config.set_node_config(config.node_id, config) return self - def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder": + def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder: """Set default configuration for a node type.""" self._config.set_default_config(node_type, 
config) return self diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py new file mode 100644 index 0000000000..822b6a808f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -0,0 +1,231 @@ +"""Tests for ResponseStreamCoordinator object field streaming.""" + +from unittest.mock import MagicMock + +from core.workflow.entities.tool_entities import ToolResultStatus +from core.workflow.enums import NodeType +from core.workflow.graph.graph import Graph +from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from core.workflow.graph_engine.response_coordinator.session import ResponseSession +from core.workflow.graph_events import ( + ChunkType, + NodeRunStreamChunkEvent, + ToolCall, + ToolResult, +) +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.template import Template, VariableSegment +from core.workflow.runtime import VariablePool + + +class TestResponseCoordinatorObjectStreaming: + """Test streaming of object-type variables with child fields.""" + + def test_object_field_streaming(self): + """Test that when selecting an object variable, all child field streams are forwarded.""" + # Create mock graph and variable pool + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + # Mock nodes + llm_node = MagicMock() + llm_node.id = "llm_node" + llm_node.node_type = NodeType.LLM + llm_node.execution_type = MagicMock() + llm_node.blocks_variable_output = MagicMock(return_value=False) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + response_node.execution_type = MagicMock() + response_node.blocks_variable_output = MagicMock(return_value=False) + + # Mock template for response node + response_node.node_data = MagicMock(spec=BaseNodeData) + response_node.node_data.answer = "{{#llm_node.generation#}}" + + graph.nodes = { + "llm_node": llm_node, + "response_node": response_node, + } + graph.root_node = llm_node + graph.get_outgoing_edges = MagicMock(return_value=[]) + + # Create coordinator + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Track execution + coordinator.track_node_execution("llm_node", "exec_123") + coordinator.track_node_execution("response_node", "exec_456") + + # Simulate streaming events for child fields of generation object + # 1. Content stream + content_event_1 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk="Hello", + is_final=False, + chunk_type=ChunkType.TEXT, + ) + content_event_2 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk=" world", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + # 2. Tool call stream + tool_call_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_calls"], + chunk='{"query": "test"}', + is_final=True, + chunk_type=ChunkType.TOOL_CALL, + tool_call=ToolCall( + id="call_123", + name="search", + arguments='{"query": "test"}', + ), + ) + + # 3. 
Tool result stream + tool_result_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_results"], + chunk="Found 10 results", + is_final=True, + chunk_type=ChunkType.TOOL_RESULT, + tool_result=ToolResult( + id="call_123", + name="search", + output="Found 10 results", + files=[], + status=ToolResultStatus.SUCCESS, + ), + ) + + # Intercept these events + coordinator.intercept_event(content_event_1) + coordinator.intercept_event(tool_call_event) + coordinator.intercept_event(tool_result_event) + coordinator.intercept_event(content_event_2) + + # Verify that all child streams are buffered + assert ("llm_node", "generation", "content") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers + + # Verify payloads are preserved in buffered events + buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0] + assert buffered_call.tool_call is not None + assert buffered_call.tool_call.id == "call_123" + buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0] + assert buffered_result.tool_result is not None + assert buffered_result.tool_result.status == "success" + + # Verify we can find child streams + child_streams = coordinator._find_child_streams(["llm_node", "generation"]) + assert len(child_streams) == 3 + assert ("llm_node", "generation", "content") in child_streams + assert ("llm_node", "generation", "tool_calls") in child_streams + assert ("llm_node", "generation", "tool_results") in child_streams + + def test_find_child_streams(self): + """Test the _find_child_streams method.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some mock streams + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + ("node1", "generation", "tool_calls"): [], + ("node1", "generation", "thought"): [], + ("node1", "text"): [], # Not a child of generation + ("node2", "generation", "content"): [], # Different node + } + + # Find children of node1.generation + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children + assert ("node1", "text") not in children + assert ("node2", "generation", "content") not in children + + def test_find_child_streams_with_closed_streams(self): + """Test that _find_child_streams also considers closed streams.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some streams - some buffered, some closed + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + } + coordinator._closed_streams = { + ("node1", "generation", "tool_calls"), + ("node1", "generation", "thought"), + } + + # Should find all children regardless of whether they're in buffers or closed + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children + + def 
test_special_selector_rewrites_to_active_response_node(self): + """Ensure special selectors attribute streams to the active response node.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + graph.nodes = {"response_node": response_node} + graph.root_node = response_node + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + coordinator.track_node_execution("response_node", "exec_resp") + + coordinator._active_session = ResponseSession( + node_id="response_node", + template=Template(segments=[VariableSegment(selector=["sys", "foo"])]), + ) + + event = NodeRunStreamChunkEvent( + id="stream_1", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["sys", "foo"], + chunk="hi", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + coordinator._stream_buffers[("sys", "foo")] = [event] + coordinator._stream_positions[("sys", "foo")] = 0 + coordinator._closed_streams.add(("sys", "foo")) + + events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"])) + + assert is_complete + assert len(events) == 1 + rewritten = events[0] + assert rewritten.node_id == "response_node" + assert rewritten.id == "exec_resp" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py new file mode 100644 index 0000000000..ea8d3a977f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -0,0 +1,539 @@ +""" +Unit tests for stop_event functionality in GraphEngine. + +Tests the unified stop_event management by GraphEngine and its propagation +to WorkerPool, Worker, Dispatcher, and Nodes. 
+""" + +import threading +import time +from unittest.mock import MagicMock, Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import UserFrom + + +class TestStopEventPropagation: + """Test suite for stop_event propagation through GraphEngine components.""" + + def test_graph_engine_creates_stop_event(self): + """Test that GraphEngine creates a stop_event on initialization.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify stop_event was created + assert engine._stop_event is not None + assert isinstance(engine._stop_event, threading.Event) + + # Verify it was set in graph_runtime_state + assert runtime_state.stop_event is not None + assert runtime_state.stop_event is engine._stop_event + + def test_stop_event_cleared_on_start(self): + """Test that stop_event is cleared when execution starts.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set the stop_event before running + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear the stop_event) + events = list(engine.run()) + + # After running, stop_event should be set again (by _stop_execution) + # But during start it was cleared + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + def test_stop_event_set_on_stop(self): + """Test that stop_event is set when execution stops.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + 
id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Initially not set + assert not engine._stop_event.is_set() + + # Run the engine + list(engine.run()) + + # After execution completes, stop_event should be set + assert engine._stop_event.is_set() + + def test_stop_event_passed_to_worker_pool(self): + """Test that stop_event is passed to WorkerPool.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify WorkerPool has the stop_event + assert engine._worker_pool._stop_event is not None + assert engine._worker_pool._stop_event is engine._stop_event + + def test_stop_event_passed_to_dispatcher(self): + """Test that stop_event is passed to Dispatcher.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify Dispatcher has the stop_event + assert engine._dispatcher._stop_event is not None + assert engine._dispatcher._stop_event is engine._stop_event + + +class TestNodeStopCheck: + """Test suite for Node._should_stop() functionality.""" + + def test_node_should_stop_checks_runtime_state(self): + """Test that Node._should_stop() checks GraphRuntimeState.stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Initially stop_event is not set + assert not answer_node._should_stop() + + # Set the stop_event + runtime_state.stop_event.set() + + # Now _should_stop should return True + assert answer_node._should_stop() + + def test_node_run_checks_stop_event_between_yields(self): + """Test that Node.run() checks stop_event between yielding events.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a simple node + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + 
 app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Set stop_event BEFORE running the node + runtime_state.stop_event.set() + + # Run the node - should yield start event then detect stop + # The node should check stop_event before processing + assert answer_node._should_stop(), "stop_event should be set" + + # Run and collect events + events = list(answer_node.run()) + + # Since stop_event is set at the start, we should get: + # 1. NodeRunStartedEvent (always yielded first) + # 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast) + assert len(events) >= 2 + assert isinstance(events[0], NodeRunStartedEvent) + + # Note: AnswerNode is very simple and might complete before stop check + # The important thing is that _should_stop() returns True when stop_event is set + assert answer_node._should_stop() + + +class TestStopEventIntegration: + """Integration tests for stop_event in workflow execution.""" + + def test_simple_workflow_respects_stop_event(self): + """Test that a simple workflow respects stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + # Create start and answer nodes + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + mock_graph.nodes["start"] = start_node + mock_graph.nodes["answer"] = answer_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set stop_event before running + runtime_state.stop_event.set() + + # Run the engine + events = list(engine.run()) + + # The started event is always emitted, even when a stop is requested + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + # The workflow may still complete (the start node runs quickly), + # but the answer node might be cancelled depending on timing + + def test_stop_event_with_concurrent_nodes(self): + """Test stop_event behavior with multiple concurrent nodes.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + # Create multiple nodes + for i in range(3): + answer_node = AnswerNode( + id=f"answer_{i}", + config={"id": f"answer_{i}", "data": {"title": 
f"answer_{i}", "answer": f"test{i}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes[f"answer_{i}"] = answer_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # All nodes should share the same stop_event + for node in mock_graph.nodes.values(): + assert node.graph_runtime_state.stop_event is runtime_state.stop_event + assert node.graph_runtime_state.stop_event is engine._stop_event + + +class TestStopEventTimeoutBehavior: + """Test stop_event behavior with join timeouts.""" + + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): + """Test that Dispatcher uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + dispatcher = engine._dispatcher + dispatcher.start() # This will create and start the mocked thread + + mock_thread_instance = mock_thread_cls.return_value + mock_thread_instance.is_alive.return_value = True + + dispatcher.stop() + + mock_thread_instance.join.assert_called_once_with(timeout=2.0) + + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): + """Test that WorkerPool uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + worker_pool = engine._worker_pool + worker_pool.start(initial_count=1) # Start with one worker + + mock_worker_instance = mock_worker_cls.return_value + mock_worker_instance.is_alive.return_value = True + + worker_pool.stop() + + mock_worker_instance.join.assert_called_once_with(timeout=2.0) + + +class TestStopEventResumeBehavior: + """Test stop_event behavior during workflow resume.""" + + def test_stop_event_cleared_on_resume(self): + """Test that stop_event is cleared when resuming a paused workflow.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Simulate a previous execution that set stop_event + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear stop_event in _start_execution) + events = list(engine.run()) + + # Execution should complete successfully + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + +class TestWorkerStopBehavior: + """Test Worker behavior with shared stop_event.""" + + def test_worker_uses_shared_stop_event(self): + """Test that Worker uses shared stop_event from GraphEngine.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Get the worker pool and check workers + worker_pool = engine._worker_pool + + # Start the worker pool to create workers + worker_pool.start() + + # Check that at least one worker was created + assert len(worker_pool._workers) > 0 + + # Verify workers use the shared stop_event + for worker in worker_pool._workers: + assert worker._stop_event is engine._stop_event + + # Clean up + worker_pool.stop() + + def test_worker_stop_is_noop(self): + """Test that Worker.stop() is now a no-op.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a mock worker + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from core.workflow.graph_engine.worker import Worker + + ready_queue = InMemoryReadyQueue() + event_queue = MagicMock() + + # Create a proper mock graph with real dict + mock_graph = Mock(spec=Graph) + mock_graph.nodes = {} # Use real dict + + stop_event = threading.Event() + + worker = Worker( + ready_queue=ready_queue, + event_queue=event_queue, + graph=mock_graph, + layers=[], + stop_event=stop_event, + ) + + # Calling stop() should do nothing (no-op) + # and should NOT set the stop_event + worker.stop() + assert not stop_event.is_set() diff --git a/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py new file mode 100644 index 0000000000..951149e933 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py @@ -0,0 +1,328 @@ +"""Tests for StreamChunkEvent and its subclasses.""" + +from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus +from core.workflow.node_events import ( + ChunkType, + StreamChunkEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, +) + + +class TestChunkType: + """Tests for ChunkType enum.""" + + def test_chunk_type_values(self): + """Test that ChunkType has expected values.""" + assert ChunkType.TEXT == "text" + assert ChunkType.TOOL_CALL == "tool_call" + assert ChunkType.TOOL_RESULT == "tool_result" + assert ChunkType.THOUGHT == "thought" + + def 
test_chunk_type_is_str_enum(self): + """Test that ChunkType values are strings.""" + for chunk_type in ChunkType: + assert isinstance(chunk_type.value, str) + + +class TestStreamChunkEvent: + """Tests for base StreamChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating StreamChunkEvent with required fields.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="Hello", + ) + + assert event.selector == ["node1", "text"] + assert event.chunk == "Hello" + assert event.is_final is False + assert event.chunk_type == ChunkType.TEXT + + def test_create_with_all_fields(self): + """Test creating StreamChunkEvent with all fields.""" + event = StreamChunkEvent( + selector=["node1", "output"], + chunk="World", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + assert event.selector == ["node1", "output"] + assert event.chunk == "World" + assert event.is_final is True + assert event.chunk_type == ChunkType.TEXT + + def test_default_chunk_type_is_text(self): + """Test that default chunk_type is TEXT.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="test", + ) + + assert event.chunk_type == ChunkType.TEXT + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="Hello", + is_final=True, + ) + + data = event.model_dump() + + assert data["selector"] == ["node1", "text"] + assert data["chunk"] == "Hello" + assert data["is_final"] is True + assert data["chunk_type"] == "text" + + +class TestToolCallChunkEvent: + """Tests for ToolCallChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolCallChunkEvent with required fields.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call=ToolCall(id="call_123", name="weather", arguments=None), + ) + + assert event.selector == ["node1", "tool_calls"] + assert event.chunk == '{"city": "Beijing"}' + assert event.tool_call.id == "call_123" + assert event.tool_call.name == "weather" + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_chunk_type_is_tool_call(self): + """Test that chunk_type is always TOOL_CALL.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk="", + tool_call=ToolCall(id="call_123", name="test_tool", arguments=None), + ) + + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_tool_arguments_field(self): + """Test tool_arguments field.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"param": "value"}', + tool_call=ToolCall( + id="call_123", + name="test_tool", + arguments='{"param": "value"}', + ), + ) + + assert event.tool_call.arguments == '{"param": "value"}' + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call=ToolCall( + id="call_123", + name="weather", + arguments='{"city": "Beijing"}', + ), + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_call" + assert data["tool_call"]["id"] == "call_123" + assert data["tool_call"]["name"] == "weather" + assert data["tool_call"]["arguments"] == '{"city": "Beijing"}' + assert data["is_final"] is True + + +class TestToolResultChunkEvent: + """Tests for ToolResultChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolResultChunkEvent with required fields.""" + event = ToolResultChunkEvent( + 
selector=["node1", "tool_results"], + chunk="Weather: Sunny, 25°C", + tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"), + ) + + assert event.selector == ["node1", "tool_results"] + assert event.chunk == "Weather: Sunny, 25°C" + assert event.tool_result.id == "call_123" + assert event.tool_result.name == "weather" + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_chunk_type_is_tool_result(self): + """Test that chunk_type is always TOOL_RESULT.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult(id="call_123", name="test_tool"), + ) + + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_tool_files_default_empty(self): + """Test that tool_files defaults to empty list.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult(id="call_123", name="test_tool"), + ) + + assert event.tool_result.files == [] + + def test_tool_files_with_values(self): + """Test tool_files with file IDs.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult( + id="call_123", + name="test_tool", + files=["file_1", "file_2"], + ), + ) + + assert event.tool_result.files == ["file_1", "file_2"] + + def test_tool_error_output(self): + """Test error output captured in tool_result.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="", + tool_result=ToolResult( + id="call_123", + name="test_tool", + output="Tool execution failed", + status=ToolResultStatus.ERROR, + ), + ) + + assert event.tool_result.output == "Tool execution failed" + assert event.tool_result.status == ToolResultStatus.ERROR + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="Weather: Sunny", + tool_result=ToolResult( + id="call_123", + name="weather", + output="Weather: Sunny", + files=["file_1"], + status=ToolResultStatus.SUCCESS, + ), + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_result" + assert data["tool_result"]["id"] == "call_123" + assert data["tool_result"]["name"] == "weather" + assert data["tool_result"]["files"] == ["file_1"] + assert data["is_final"] is True + + +class TestThoughtChunkEvent: + """Tests for ThoughtChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ThoughtChunkEvent with required fields.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to query the weather...", + ) + + assert event.selector == ["node1", "thought"] + assert event.chunk == "I need to query the weather..." + assert event.chunk_type == ChunkType.THOUGHT + + def test_chunk_type_is_thought(self): + """Test that chunk_type is always THOUGHT.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert event.chunk_type == ChunkType.THOUGHT + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to analyze this...", + is_final=False, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "thought" + assert data["chunk"] == "I need to analyze this..." 
+ assert data["is_final"] is False + + +class TestEventInheritance: + """Tests for event inheritance relationships.""" + + def test_tool_call_is_stream_chunk(self): + """Test that ToolCallChunkEvent is a StreamChunkEvent.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk="", + tool_call=ToolCall(id="call_123", name="test", arguments=None), + ) + + assert isinstance(event, StreamChunkEvent) + + def test_tool_result_is_stream_chunk(self): + """Test that ToolResultChunkEvent is a StreamChunkEvent.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult(id="call_123", name="test"), + ) + + assert isinstance(event, StreamChunkEvent) + + def test_thought_is_stream_chunk(self): + """Test that ThoughtChunkEvent is a StreamChunkEvent.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert isinstance(event, StreamChunkEvent) + + def test_all_events_have_common_fields(self): + """Test that all events have common StreamChunkEvent fields.""" + events = [ + StreamChunkEvent(selector=["n", "t"], chunk="a"), + ToolCallChunkEvent( + selector=["n", "t"], + chunk="b", + tool_call=ToolCall(id="1", name="t", arguments=None), + ), + ToolResultChunkEvent( + selector=["n", "t"], + chunk="c", + tool_result=ToolResult(id="1", name="t"), + ), + ThoughtChunkEvent(selector=["n", "t"], chunk="d"), + ] + + for event in events: + assert hasattr(event, "selector") + assert hasattr(event, "chunk") + assert hasattr(event, "is_final") + assert hasattr(event, "chunk_type") diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index e8f257bf2f..1e224d56a5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -78,7 +78,7 @@ class TestFileSaverImpl: file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py new file mode 100644 index 0000000000..9d793f804f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py @@ -0,0 +1,148 @@ +import types +from collections.abc import Generator +from typing import Any + +import pytest + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities import ToolCallResult +from core.workflow.entities.tool_entities import ToolResultStatus +from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase +from core.workflow.nodes.llm.node import LLMNode + + +class _StubModelInstance: + """Minimal stub to satisfy _stream_llm_events signature.""" + + provider_model_bundle = None + + +def _drain(generator: Generator[NodeEventBase, None, Any]): + events: list = [] + try: + while True: + events.append(next(generator)) + except StopIteration as exc: + return events, exc.value + + +@pytest.fixture(autouse=True) +def patch_deduct_llm_quota(monkeypatch): + # Avoid touching real quota logic during unit tests + 
monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
+
+
+def _make_llm_node(reasoning_format: str) -> LLMNode:
+    node = LLMNode.__new__(LLMNode)
+    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
+    object.__setattr__(node, "tenant_id", "tenant")
+    return node
+
+
+def test_stream_llm_events_extracts_reasoning_for_tagged():
+    node = _make_llm_node(reasoning_format="tagged")
+    tagged_text = "<think>Thought</think>Answer"
+    usage = LLMUsage.empty_usage()
+
+    def generator():
+        yield ModelInvokeCompletedEvent(
+            text=tagged_text,
+            usage=usage,
+            finish_reason="stop",
+            reasoning_content="",
+            structured_output=None,
+        )
+
+    events, returned = _drain(
+        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
+    )
+
+    assert events == []
+    clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
+    assert clean_text == tagged_text  # original tagged text preserved for output
+    assert reasoning_content == ""  # tagged mode keeps reasoning separate
+    assert gen_clean == "Answer"  # stripped content for generation
+    assert gen_reasoning == "Thought"  # reasoning extracted from the <think> tag
+    assert ret_usage == usage
+    assert finish_reason == "stop"
+    assert structured is None
+    assert gen_data is None
+
+    # generation building should include reasoning and sequence
+    generation_content = gen_clean or clean_text
+    sequence = [
+        {"type": "reasoning", "index": 0},
+        {"type": "content", "start": 0, "end": len(generation_content)},
+    ]
+    assert sequence == [
+        {"type": "reasoning", "index": 0},
+        {"type": "content", "start": 0, "end": len("Answer")},
+    ]
+
+
+def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
+    node = _make_llm_node(reasoning_format="tagged")
+    plain_text = "Hello world"
+    usage = LLMUsage.empty_usage()
+
+    def generator():
+        yield ModelInvokeCompletedEvent(
+            text=plain_text,
+            usage=usage,
+            finish_reason=None,
+            reasoning_content="",
+            structured_output=None,
+        )
+
+    events, returned = _drain(
+        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
+    )
+
+    assert events == []
+    _, _, gen_reasoning, gen_clean, *_ = returned
+    generation_content = gen_clean or plain_text
+    assert gen_reasoning == ""
+    assert generation_content == plain_text
+    # Empty reasoning should imply empty sequence in generation construction
+    sequence = []
+    assert sequence == []
+
+
+def test_serialize_tool_call_strips_files_to_ids():
+    file_cls = pytest.importorskip("core.file").File
+    file_type = pytest.importorskip("core.file.enums").FileType
+    transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod
+
+    file_with_id = file_cls(
+        id="f1",
+        tenant_id="t",
+        type=file_type.IMAGE,
+        transfer_method=transfer_method.REMOTE_URL,
+        remote_url="http://example.com/f1",
+        storage_key="k1",
+    )
+    file_with_related = file_cls(
+        id=None,
+        tenant_id="t",
+        type=file_type.IMAGE,
+        transfer_method=transfer_method.REMOTE_URL,
+        related_id="rel2",
+        remote_url="http://example.com/f2",
+        storage_key="k2",
+    )
+    tool_call = ToolCallResult(
+        id="tc",
+        name="do",
+        arguments='{"a":1}',
+        output="ok",
+        files=[file_with_id, file_with_related],
+        status=ToolResultStatus.SUCCESS,
+    )
+
+    serialized = LLMNode._serialize_tool_call(tool_call)
+
+    assert serialized["files"] == ["f1", "rel2"]
+    assert serialized["id"] == "tc"
+    assert serialized["name"] == "do"
+    assert
serialized["arguments"] == '{"a":1}' + assert serialized["output"] == "ok" diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 09b8191870..06927cddcf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import types from collections.abc import Generator @@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only @pytest.fixture -def tool_node(monkeypatch) -> "ToolNode": +def tool_node(monkeypatch) -> ToolNode: module_name = "core.ops.ops_trace_manager" if module_name not in sys.modules: ops_stub = types.ModuleType(module_name) @@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: +def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: def _identity_transform(messages, *_args, **_kwargs): return messages @@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l return _collect_events(generator) -def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): +def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( tenant_id="tenant-id", type=FileType.DOCUMENT, @@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): assert files_segment.value == [file_obj] -def test_plain_link_messages_remain_links(tool_node: "ToolNode"): +def test_plain_link_messages_remain_links(tool_node: ToolNode): message = ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index c62fc4d8fe..1df75380af 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -1,14 +1,14 @@ import time import uuid -from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph +from core.workflow.graph_events.node import NodeRunSucceededEvent from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -86,9 +86,6 @@ def test_overwrite_string_variable(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -104,20 
+101,14 @@ def test_overwrite_string_variable(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = StringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=input_variable.value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == input_variable.value got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -191,9 +182,6 @@ def test_append_variable_to_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -209,22 +197,14 @@ def test_append_variable_to_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_value = list(conversation_variable.value) - expected_value.append(input_variable.value) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=expected_value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == ["the first value", "the second value"] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -287,9 +267,6 @@ def test_clear_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -305,20 +282,14 @@ def test_clear_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=[], - ) - 
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == [] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index caa36734ad..353d56fe25 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -390,3 +390,42 @@ def test_remove_last_from_empty_array(): got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] + + +def test_node_factory_creates_variable_assigner_node(): + graph_config = { + "edges": [], + "nodes": [ + { + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + variable_pool = VariablePool( + system_variables=SystemVariable(conversation_id="conversation_id"), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + node_factory = DifyNodeFactory( + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + ) + + node = node_factory.create_node(graph_config["nodes"][0]) + + assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py new file mode 100644 index 0000000000..8be8df16f4 --- /dev/null +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig + + +def test_file_response_serializes_datetime() -> None: + created_at = datetime(2024, 1, 1, 12, 0, 0) + file_obj = SimpleNamespace( + id="file-1", + name="example.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by="user-1", + created_at=created_at, + preview_url="https://preview", + source_url="https://source", + original_url="https://origin", + user_id="user-1", + tenant_id="tenant-1", + conversation_id="conv-1", + file_key="key-1", + ) + + serialized = FileResponse.model_validate(file_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["id"] == "file-1" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["preview_url"] == "https://preview" + assert serialized["source_url"] == "https://source" + assert serialized["original_url"] == "https://origin" + assert serialized["user_id"] == "user-1" + assert 
serialized["tenant_id"] == "tenant-1" + assert serialized["conversation_id"] == "conv-1" + assert serialized["file_key"] == "key-1" + + +def test_file_with_signed_url_builds_payload() -> None: + payload = FileWithSignedUrl( + id="file-2", + name="remote.pdf", + size=2048, + extension="pdf", + url="https://signed", + mime_type="application/pdf", + created_by="user-2", + created_at=datetime(2024, 1, 2, 0, 0, 0), + ) + + dumped = payload.model_dump(mode="json") + + assert dumped["url"] == "https://signed" + assert dumped["created_at"] == int(datetime(2024, 1, 2, 0, 0, 0).timestamp()) + + +def test_remote_file_info_and_upload_config() -> None: + info = RemoteFileInfo(file_type="text/plain", file_length=123) + assert info.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123} + + config = UploadConfig( + file_size_limit=1, + batch_count_limit=2, + file_upload_limit=3, + image_file_size_limit=4, + video_file_size_limit=5, + audio_file_size_limit=6, + workflow_file_upload_limit=7, + image_file_batch_limit=8, + single_chunk_attachment_limit=9, + attachment_image_file_size_limit=10, + ) + + dumped = config.model_dump(mode="json") + assert dumped["file_upload_limit"] == 3 + assert dumped["attachment_image_file_size_limit"] == 10 diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 85789bfa7e..de74eff82f 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,6 @@ import pytest -from libs.helper import extract_tenant_id +from libs.helper import escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -63,3 +63,51 @@ class TestExtractTenantId: with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): extract_tenant_id(dict_user) + + +class TestEscapeLikePattern: + """Test cases for the escape_like_pattern utility function.""" + + def test_escape_percent_character(self): + """Test escaping percent character.""" + result = escape_like_pattern("50% discount") + assert result == "50\\% discount" + + def test_escape_underscore_character(self): + """Test escaping underscore character.""" + result = escape_like_pattern("test_data") + assert result == "test\\_data" + + def test_escape_backslash_character(self): + """Test escaping backslash character.""" + result = escape_like_pattern("path\\to\\file") + assert result == "path\\\\to\\\\file" + + def test_escape_combined_special_characters(self): + """Test escaping multiple special characters together.""" + result = escape_like_pattern("file_50%\\path") + assert result == "file\\_50\\%\\\\path" + + def test_escape_empty_string(self): + """Test escaping empty string returns empty string.""" + result = escape_like_pattern("") + assert result == "" + + def test_escape_none_handling(self): + """Test escaping None returns None (falsy check handles it).""" + # The function checks `if not pattern`, so None is falsy and returns as-is + result = escape_like_pattern(None) + assert result is None + + def test_escape_normal_string_no_change(self): + """Test that normal strings without special characters are unchanged.""" + result = escape_like_pattern("normal text") + assert result == "normal text" + + def test_escape_order_matters(self): + """Test that backslash is escaped first to prevent double escaping.""" + # If we escape % first, then escape \, we might get wrong results + # This test ensures the order is correct: \ first, then % and _ + result = 
escape_like_pattern("test\\%_value") + # Should be: test\\\%\_value + assert result == "test\\\\\\%\\_value" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e35788660d..8be2eea121 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -114,7 +114,7 @@ class TestAppModelValidation: def test_icon_type_validation(self): """Test icon type enum values.""" # Assert - assert {t.value for t in IconType} == {"image", "emoji"} + assert {t.value for t in IconType} == {"image", "emoji", "link"} def test_app_desc_or_prompt_with_description(self): """Test desc_or_prompt property when description exists.""" diff --git a/docker/.env.example b/docker/.env.example index 66d937e8e7..09ee1060e2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -58,8 +58,8 @@ FILES_URL= INTERNAL_FILES_URL= # Ensure UTF-8 encoding -LANG=en_US.UTF-8 -LC_ALL=en_US.UTF-8 +LANG=C.UTF-8 +LC_ALL=C.UTF-8 PYTHONIOENCODING=utf-8 # ------------------------------ diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 3c88cddf8c..709aff23df 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -475,7 +475,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 54b9e744f8..712de84c62 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -13,8 +13,8 @@ x-shared-env: &shared-api-worker-env APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} - LANG: ${LANG:-en_US.UTF-8} - LC_ALL: ${LC_ALL:-en_US.UTF-8} + LANG: ${LANG:-C.UTF-8} + LC_ALL: ${LC_ALL:-C.UTF-8} PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} @@ -1157,7 +1157,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/web/__tests__/goto-anything/search-error-handling.test.ts b/web/__tests__/goto-anything/search-error-handling.test.ts index 3a495834cd..42eb829583 100644 --- a/web/__tests__/goto-anything/search-error-handling.test.ts +++ b/web/__tests__/goto-anything/search-error-handling.test.ts @@ -14,6 +14,14 @@ import { fetchAppList } from '@/service/apps' import { postMarketplace } from '@/service/base' import { fetchDatasets } from '@/service/datasets' +// Mock react-i18next before importing modules that use it +vi.mock('react-i18next', () => ({ + getI18n: () => ({ + t: (key: string) => key, + language: 'en', + }), +})) + // Mock API functions vi.mock('@/service/base', () => ({ postMarketplace: vi.fn(), diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index 8080b565cd..1d65e4de53 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,11 +1,8 @@ -/* eslint-disable dify-i18n/require-ns-option */ -import * as React from 'react' +import { useTranslation } from '#i18n' 
import Form from '@/app/components/datasets/settings/form' -import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' -const Settings = async () => { - const locale = await getLocaleOnServer() - const { t } = await getTranslation(locale, 'dataset-settings') +const Settings = () => { + const { t } = useTranslation('datasetSettings') return (
diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx index 2df9cf23c4..81bda3a8a3 100644 --- a/web/app/(commonLayout)/plugins/page.tsx +++ b/web/app/(commonLayout)/plugins/page.tsx @@ -1,14 +1,12 @@ import Marketplace from '@/app/components/plugins/marketplace' import PluginPage from '@/app/components/plugins/plugin-page' import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel' -import { getLocaleOnServer } from '@/i18n-config/server' const PluginList = async () => { - const locale = await getLocaleOnServer() return ( } - marketplace={} + marketplace={} /> ) } diff --git a/web/app/components/apps/app-card-skeleton.tsx b/web/app/components/apps/app-card-skeleton.tsx new file mode 100644 index 0000000000..806f19973a --- /dev/null +++ b/web/app/components/apps/app-card-skeleton.tsx @@ -0,0 +1,41 @@ +'use client' + +import * as React from 'react' +import { SkeletonContainer, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' + +type AppCardSkeletonProps = { + count?: number +} + +/** + * Skeleton placeholder for App cards during loading states. + * Matches the visual layout of AppCard component. + */ +export const AppCardSkeleton = React.memo(({ count = 6 }: AppCardSkeletonProps) => { + return ( + <> + {Array.from({ length: count }).map((_, index) => ( +
+ + + +
+ + +
+
+
+ + +
+
+
+ ))} + + ) +}) + +AppCardSkeleton.displayName = 'AppCardSkeleton' diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 003b463595..290a73fc7c 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -27,7 +27,9 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' +import { cn } from '@/utils/classnames' import AppCard from './app-card' +import { AppCardSkeleton } from './app-card-skeleton' import Empty from './empty' import Footer from './footer' import useAppsQueryState from './hooks/use-apps-query-state' @@ -45,7 +47,7 @@ const List = () => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() - const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext() + const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( 'category', @@ -89,6 +91,7 @@ const List = () => { const { data, isLoading, + isFetching, isFetchingNextPage, fetchNextPage, hasNextPage, @@ -172,6 +175,8 @@ const List = () => { const pages = data?.pages ?? [] const hasAnyApp = (pages[0]?.total ?? 0) > 0 + // Show skeleton during initial load or when refetching with no previous data + const showSkeleton = isLoading || (isFetching && pages.length === 0) return ( <> @@ -205,23 +210,34 @@ const List = () => { />
- {hasAnyApp - ? ( -
- {isCurrentWorkspaceEditor - && } - {pages.map(({ data: apps }) => apps.map(app => ( - - )))} -
- ) - : ( -
- {isCurrentWorkspaceEditor - && } - -
- )} +
+ {(isCurrentWorkspaceEditor || isLoadingCurrentWorkspace) && ( + + )} + {(() => { + if (showSkeleton) + return + + if (hasAnyApp) { + return pages.flatMap(({ data: apps }) => apps).map(app => ( + + )) + } + + // No apps - show empty state + return + })()} +
{isCurrentWorkspaceEditor && (
import('@/app/components/app/create-fro export type CreateAppCardProps = { className?: string + isLoading?: boolean onSuccess?: () => void ref: React.RefObject selectedAppType?: string @@ -33,6 +34,7 @@ export type CreateAppCardProps = { const CreateAppCard = ({ ref, className, + isLoading = false, onSuccess, selectedAppType, }: CreateAppCardProps) => { @@ -56,7 +58,11 @@ const CreateAppCard = ({ return (
{t('createApp', { ns: 'app' })}
diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index 32ef133453..a6d51d8643 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -17,7 +17,7 @@ vi.mock('@/hooks/use-app-favicon', () => ({ useAppFavicon: vi.fn(), })) -vi.mock('@/i18n-config/i18next-config', () => ({ +vi.mock('@/i18n-config/client', () => ({ changeLanguage: vi.fn().mockResolvedValue(undefined), })) diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 5ff8e61ff6..ed1981b530 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -25,7 +25,7 @@ import { useToastContext } from '@/app/components/base/toast' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { delConversation, pinConversation, diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx index ca6a90c4d8..066fb8ebe9 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx @@ -13,7 +13,7 @@ import { shareQueryKeys } from '@/service/use-share' import { CONVERSATION_ID_INFO } from '../constants' import { useEmbeddedChatbot } from './hooks' -vi.mock('@/i18n-config/i18next-config', () => ({ +vi.mock('@/i18n-config/client', () => ({ changeLanguage: vi.fn().mockResolvedValue(undefined), })) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 803e905837..9028d10000 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -23,7 +23,7 @@ import { useToastContext } from '@/app/components/base/toast' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { updateFeedback } from '@/service/share' import { useInvalidateShareConversations, diff --git a/web/app/components/base/form/hooks/use-get-form-values.ts b/web/app/components/base/form/hooks/use-get-form-values.ts index 9ea418ea00..3dd2eceb30 100644 --- a/web/app/components/base/form/hooks/use-get-form-values.ts +++ b/web/app/components/base/form/hooks/use-get-form-values.ts @@ -4,7 +4,7 @@ import type { GetValuesOptions, } from '../types' import { useCallback } from 'react' -import { getTransformedValuesWhenSecretInputPristine } from '../utils' +import { getTransformedValuesWhenSecretInputPristine } from '../utils/secret-input' import { useCheckValidated } from './use-check-validated' export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => { diff --git a/web/app/components/base/form/utils/index.ts b/web/app/components/base/form/utils/index.ts deleted file mode 100644 index 0abb8d1ad5..0000000000 --- 
a/web/app/components/base/form/utils/index.ts
+++ /dev/null
@@ -1 +0,0 @@
-export * from './secret-input'
diff --git a/web/app/components/base/form/utils/zod-submit-validator.ts b/web/app/components/base/form/utils/zod-submit-validator.ts
new file mode 100644
index 0000000000..23eacaf8a4
--- /dev/null
+++ b/web/app/components/base/form/utils/zod-submit-validator.ts
@@ -0,0 +1,22 @@
+import type { ZodSchema } from 'zod'
+
+type SubmitValidator<T> = ({ value }: { value: T }) => { fields: Record<string, string> } | undefined
+
+export const zodSubmitValidator = <T>(schema: ZodSchema<T>): SubmitValidator<T> => {
+  return ({ value }) => {
+    const result = schema.safeParse(value)
+    if (!result.success) {
+      const fieldErrors: Record<string, string> = {}
+      for (const issue of result.error.issues) {
+        const path = issue.path[0]
+        if (path === undefined)
+          continue
+        const key = String(path)
+        if (!fieldErrors[key])
+          fieldErrors[key] = issue.message
+      }
+      return { fields: fieldErrors }
+    }
+    return undefined
+  }
+}
diff --git a/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx b/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx
new file mode 100644
index 0000000000..6bdc2ace56
--- /dev/null
+++ b/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx
@@ -0,0 +1,97 @@
+'use client'
+
+import { useCallback, useMemo } from 'react'
+import { useTranslation } from 'react-i18next'
+import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
+import { DataSourceType } from '@/models/datasets'
+import { cn } from '@/utils/classnames'
+import s from '../index.module.css'
+
+type DataSourceTypeSelectorProps = {
+  currentType: DataSourceType
+  disabled: boolean
+  onChange: (type: DataSourceType) => void
+  onClearPreviews: (type: DataSourceType) => void
+}
+
+type DataSourceLabelKey
+  = | 'stepOne.dataSourceType.file'
+    | 'stepOne.dataSourceType.notion'
+    | 'stepOne.dataSourceType.web'
+
+type DataSourceOption = {
+  type: DataSourceType
+  iconClass?: string
+  labelKey: DataSourceLabelKey
+}
+
+const DATA_SOURCE_OPTIONS: DataSourceOption[] = [
+  {
+    type: DataSourceType.FILE,
+    labelKey: 'stepOne.dataSourceType.file',
+  },
+  {
+    type: DataSourceType.NOTION,
+    iconClass: s.notion,
+    labelKey: 'stepOne.dataSourceType.notion',
+  },
+  {
+    type: DataSourceType.WEB,
+    iconClass: s.web,
+    labelKey: 'stepOne.dataSourceType.web',
+  },
+]
+
+/**
+ * Data source type selector component for choosing between file, notion, and web sources.
+ */
+function DataSourceTypeSelector({
+  currentType,
+  disabled,
+  onChange,
+  onClearPreviews,
+}: DataSourceTypeSelectorProps) {
+  const { t } = useTranslation()
+
+  const isWebEnabled = ENABLE_WEBSITE_FIRECRAWL || ENABLE_WEBSITE_JINAREADER || ENABLE_WEBSITE_WATERCRAWL
+
+  const handleTypeChange = useCallback((type: DataSourceType) => {
+    if (disabled)
+      return
+    onChange(type)
+    onClearPreviews(type)
+  }, [disabled, onChange, onClearPreviews])
+
+  const visibleOptions = useMemo(() => DATA_SOURCE_OPTIONS.filter((option) => {
+    if (option.type === DataSourceType.WEB)
+      return isWebEnabled
+    return true
+  }), [isWebEnabled])
+
+  return (
+
+ {visibleOptions.map(option => ( +
handleTypeChange(option.type)} + > + + + {t(option.labelKey, { ns: 'datasetCreation' })} + +
+ ))} +
+ ) +} + +export default DataSourceTypeSelector diff --git a/web/app/components/datasets/create/step-one/components/index.ts b/web/app/components/datasets/create/step-one/components/index.ts new file mode 100644 index 0000000000..5271835741 --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/index.ts @@ -0,0 +1,3 @@ +export { default as DataSourceTypeSelector } from './data-source-type-selector' +export { default as NextStepButton } from './next-step-button' +export { default as PreviewPanel } from './preview-panel' diff --git a/web/app/components/datasets/create/step-one/components/next-step-button.tsx b/web/app/components/datasets/create/step-one/components/next-step-button.tsx new file mode 100644 index 0000000000..71e4e87fcf --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/next-step-button.tsx @@ -0,0 +1,30 @@ +'use client' + +import { RiArrowRightLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' + +type NextStepButtonProps = { + disabled: boolean + onClick: () => void +} + +/** + * Reusable next step button component for dataset creation flow. + */ +function NextStepButton({ disabled, onClick }: NextStepButtonProps) { + const { t } = useTranslation() + + return ( +
+ +
+ ) +} + +export default NextStepButton diff --git a/web/app/components/datasets/create/step-one/components/preview-panel.tsx b/web/app/components/datasets/create/step-one/components/preview-panel.tsx new file mode 100644 index 0000000000..8ae0b7df55 --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/preview-panel.tsx @@ -0,0 +1,62 @@ +'use client' + +import type { NotionPage } from '@/models/common' +import type { CrawlResultItem } from '@/models/datasets' +import { useTranslation } from 'react-i18next' +import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' +import FilePreview from '../../file-preview' +import NotionPagePreview from '../../notion-page-preview' +import WebsitePreview from '../../website/preview' + +type PreviewPanelProps = { + currentFile: File | undefined + currentNotionPage: NotionPage | undefined + currentWebsite: CrawlResultItem | undefined + notionCredentialId: string + isShowPlanUpgradeModal: boolean + hideFilePreview: () => void + hideNotionPagePreview: () => void + hideWebsitePreview: () => void + hidePlanUpgradeModal: () => void +} + +/** + * Right panel component for displaying file, notion page, or website previews. + */ +function PreviewPanel({ + currentFile, + currentNotionPage, + currentWebsite, + notionCredentialId, + isShowPlanUpgradeModal, + hideFilePreview, + hideNotionPagePreview, + hideWebsitePreview, + hidePlanUpgradeModal, +}: PreviewPanelProps) { + const { t } = useTranslation() + + return ( +
+ {currentFile && } + {currentNotionPage && ( + + )} + {currentWebsite && } + {isShowPlanUpgradeModal && ( + + )} +
+      )}
+
+  )
+}
+
+export default PreviewPanel
diff --git a/web/app/components/datasets/create/step-one/hooks/index.ts b/web/app/components/datasets/create/step-one/hooks/index.ts
new file mode 100644
index 0000000000..bae5ce4fce
--- /dev/null
+++ b/web/app/components/datasets/create/step-one/hooks/index.ts
@@ -0,0 +1,2 @@
+export { default as usePreviewState } from './use-preview-state'
+export type { PreviewActions, PreviewState, UsePreviewStateReturn } from './use-preview-state'
diff --git a/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts b/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts
new file mode 100644
index 0000000000..3984947ab1
--- /dev/null
+++ b/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts
@@ -0,0 +1,70 @@
+'use client'
+
+import type { NotionPage } from '@/models/common'
+import type { CrawlResultItem } from '@/models/datasets'
+import { useCallback, useState } from 'react'
+
+export type PreviewState = {
+  currentFile: File | undefined
+  currentNotionPage: NotionPage | undefined
+  currentWebsite: CrawlResultItem | undefined
+}
+
+export type PreviewActions = {
+  showFilePreview: (file: File) => void
+  hideFilePreview: () => void
+  showNotionPagePreview: (page: NotionPage) => void
+  hideNotionPagePreview: () => void
+  showWebsitePreview: (website: CrawlResultItem) => void
+  hideWebsitePreview: () => void
+}
+
+export type UsePreviewStateReturn = PreviewState & PreviewActions
+
+/**
+ * Custom hook for managing preview state across different data source types.
+ * Handles file, notion page, and website preview visibility.
+ */
+function usePreviewState(): UsePreviewStateReturn {
+  const [currentFile, setCurrentFile] = useState<File>()
+  const [currentNotionPage, setCurrentNotionPage] = useState<NotionPage>()
+  const [currentWebsite, setCurrentWebsite] = useState<CrawlResultItem>()
+
+  const showFilePreview = useCallback((file: File) => {
+    setCurrentFile(file)
+  }, [])
+
+  const hideFilePreview = useCallback(() => {
+    setCurrentFile(undefined)
+  }, [])
+
+  const showNotionPagePreview = useCallback((page: NotionPage) => {
+    setCurrentNotionPage(page)
+  }, [])
+
+  const hideNotionPagePreview = useCallback(() => {
+    setCurrentNotionPage(undefined)
+  }, [])
+
+  const showWebsitePreview = useCallback((website: CrawlResultItem) => {
+    setCurrentWebsite(website)
+  }, [])
+
+  const hideWebsitePreview = useCallback(() => {
+    setCurrentWebsite(undefined)
+  }, [])
+
+  return {
+    currentFile,
+    currentNotionPage,
+    currentWebsite,
+    showFilePreview,
+    hideFilePreview,
+    showNotionPagePreview,
+    hideNotionPagePreview,
+    showWebsitePreview,
+    hideWebsitePreview,
+  }
+}
+
+export default usePreviewState
diff --git a/web/app/components/datasets/create/step-one/index.spec.tsx b/web/app/components/datasets/create/step-one/index.spec.tsx
new file mode 100644
index 0000000000..1ff77dc1f6
--- /dev/null
+++ b/web/app/components/datasets/create/step-one/index.spec.tsx
@@ -0,0 +1,1204 @@
+import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types'
+import type { NotionPage } from '@/models/common'
+import type { CrawlOptions, CrawlResultItem, DataSet, FileItem } from '@/models/datasets'
+import { act, fireEvent, render, renderHook, screen } from '@testing-library/react'
+import { Plan } from '@/app/components/billing/type'
+import { DataSourceType } from '@/models/datasets'
+import { DataSourceTypeSelector, NextStepButton, PreviewPanel } from './components'
+import { usePreviewState } from './hooks'
+import StepOne from
'./index' + +// ========================================== +// Mock External Dependencies +// ========================================== + +// Mock config for website crawl features +vi.mock('@/config', () => ({ + ENABLE_WEBSITE_FIRECRAWL: true, + ENABLE_WEBSITE_JINAREADER: false, + ENABLE_WEBSITE_WATERCRAWL: false, +})) + +// Mock dataset detail context +let mockDatasetDetail: DataSet | undefined +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet | undefined }) => DataSet | undefined) => { + return selector({ dataset: mockDatasetDetail }) + }, +})) + +// Mock provider context +let mockPlan = { + type: Plan.professional, + usage: { vectorSpace: 50, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, + total: { vectorSpace: 100, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, +} +let mockEnableBilling = false + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + enableBilling: mockEnableBilling, + }), +})) + +// Mock child components +vi.mock('../file-uploader', () => ({ + default: ({ onPreview, fileList }: { onPreview: (file: File) => void, fileList: FileItem[] }) => ( +
+ {fileList.length} + +
+ ), +})) + +vi.mock('../website', () => ({ + default: ({ onPreview }: { onPreview: (item: CrawlResultItem) => void }) => ( +
+ +
+ ), +})) + +vi.mock('../empty-dataset-creation-modal', () => ({ + default: ({ show, onHide }: { show: boolean, onHide: () => void }) => ( + show + ? ( +
+ +
+ ) + : null + ), +})) + +// NotionConnector is a base component - imported directly without mock +// It only depends on i18n which is globally mocked + +vi.mock('@/app/components/base/notion-page-selector', () => ({ + NotionPageSelector: ({ onPreview }: { onPreview: (page: NotionPage) => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/billing/vector-space-full', () => ({ + default: () =>
Vector Space Full
, +})) + +vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({ + default: ({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show + ? ( +
+ +
+ ) + : null + ), +})) + +vi.mock('../file-preview', () => ({ + default: ({ file, hidePreview }: { file: File, hidePreview: () => void }) => ( +
+ {file.name} + +
+ ), +})) + +vi.mock('../notion-page-preview', () => ({ + default: ({ currentPage, hidePreview }: { currentPage: NotionPage, hidePreview: () => void }) => ( +
+ {currentPage.page_id} + +
+ ), +})) + +// WebsitePreview is a sibling component without API dependencies - imported directly +// It only depends on i18n which is globally mocked + +vi.mock('./upgrade-card', () => ({ + default: () =>
Upgrade Card
,
+}))
+
+// ==========================================
+// Test Data Builders
+// ==========================================
+
+const createMockCustomFile = (overrides: { id?: string, name?: string } = {}) => {
+  const file = new File(['test content'], overrides.name ?? 'test.txt', { type: 'text/plain' })
+  return Object.assign(file, {
+    id: overrides.id ?? 'uploaded-id',
+    extension: 'txt',
+    mime_type: 'text/plain',
+    created_by: 'user-1',
+    created_at: Date.now(),
+  })
+}
+
+const createMockFileItem = (overrides: Partial<FileItem> = {}): FileItem => ({
+  fileID: `file-${Date.now()}`,
+  file: createMockCustomFile(overrides.file as { id?: string, name?: string }),
+  progress: 100,
+  ...overrides,
+})
+
+const createMockNotionPage = (overrides: Partial<NotionPage> = {}): NotionPage => ({
+  page_id: `page-${Date.now()}`,
+  type: 'page',
+  ...overrides,
+} as NotionPage)
+
+const createMockCrawlResult = (overrides: Partial<CrawlResultItem> = {}): CrawlResultItem => ({
+  title: 'Test Page',
+  markdown: 'Test content',
+  description: 'Test description',
+  source_url: 'https://example.com',
+  ...overrides,
+})
+
+const createMockDataSourceAuth = (overrides: Partial<DataSourceAuth> = {}): DataSourceAuth => ({
+  credential_id: 'cred-1',
+  provider: 'notion_datasource',
+  plugin_id: 'plugin-1',
+  credentials_list: [{ id: 'cred-1', name: 'Workspace 1' }],
+  ...overrides,
+} as DataSourceAuth)
+
+const defaultProps = {
+  dataSourceType: DataSourceType.FILE,
+  dataSourceTypeDisable: false,
+  onSetting: vi.fn(),
+  files: [] as FileItem[],
+  updateFileList: vi.fn(),
+  updateFile: vi.fn(),
+  notionPages: [] as NotionPage[],
+  notionCredentialId: '',
+  updateNotionPages: vi.fn(),
+  updateNotionCredentialId: vi.fn(),
+  onStepChange: vi.fn(),
+  changeType: vi.fn(),
+  websitePages: [] as CrawlResultItem[],
+  updateWebsitePages: vi.fn(),
+  onWebsiteCrawlProviderChange: vi.fn(),
+  onWebsiteCrawlJobIdChange: vi.fn(),
+  crawlOptions: {
+    crawl_sub_pages: true,
+    only_main_content: true,
+    includes: '',
+    excludes: '',
+    limit: 10,
+    max_depth: '',
+    use_sitemap: true,
+  } as CrawlOptions,
+  onCrawlOptionsChange: vi.fn(),
+  authedDataSourceList: [] as DataSourceAuth[],
+}
+
+// ==========================================
+// usePreviewState Hook Tests
+// ==========================================
+describe('usePreviewState Hook', () => {
+  // --------------------------------------------------------------------------
+  // Initial State Tests
+  // --------------------------------------------------------------------------
+  describe('Initial State', () => {
+    it('should initialize with all preview states undefined', () => {
+      // Arrange & Act
+      const { result } = renderHook(() => usePreviewState())
+
+      // Assert
+      expect(result.current.currentFile).toBeUndefined()
+      expect(result.current.currentNotionPage).toBeUndefined()
+      expect(result.current.currentWebsite).toBeUndefined()
+    })
+  })
+
+  // --------------------------------------------------------------------------
+  // File Preview Tests
+  // --------------------------------------------------------------------------
+  describe('File Preview', () => {
+    it('should show file preview when showFilePreview is called', () => {
+      // Arrange
+      const { result } = renderHook(() => usePreviewState())
+      const mockFile = new File(['test'], 'test.txt')
+
+      // Act
+      act(() => {
+        result.current.showFilePreview(mockFile)
+      })
+
+      // Assert
+      expect(result.current.currentFile).toBe(mockFile)
+    })
+
+    it('should hide file preview when hideFilePreview is called', () => {
+      // Arrange
+      const { result } = renderHook(() =>
usePreviewState()) + const mockFile = new File(['test'], 'test.txt') + + act(() => { + result.current.showFilePreview(mockFile) + }) + + // Act + act(() => { + result.current.hideFilePreview() + }) + + // Assert + expect(result.current.currentFile).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Notion Page Preview Tests + // -------------------------------------------------------------------------- + describe('Notion Page Preview', () => { + it('should show notion page preview when showNotionPagePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockPage = createMockNotionPage() + + // Act + act(() => { + result.current.showNotionPagePreview(mockPage) + }) + + // Assert + expect(result.current.currentNotionPage).toBe(mockPage) + }) + + it('should hide notion page preview when hideNotionPagePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockPage = createMockNotionPage() + + act(() => { + result.current.showNotionPagePreview(mockPage) + }) + + // Act + act(() => { + result.current.hideNotionPagePreview() + }) + + // Assert + expect(result.current.currentNotionPage).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Website Preview Tests + // -------------------------------------------------------------------------- + describe('Website Preview', () => { + it('should show website preview when showWebsitePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockWebsite = createMockCrawlResult() + + // Act + act(() => { + result.current.showWebsitePreview(mockWebsite) + }) + + // Assert + expect(result.current.currentWebsite).toBe(mockWebsite) + }) + + it('should hide website preview when hideWebsitePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockWebsite = createMockCrawlResult() + + act(() => { + result.current.showWebsitePreview(mockWebsite) + }) + + // Act + act(() => { + result.current.hideWebsitePreview() + }) + + // Assert + expect(result.current.currentWebsite).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Callback Stability Tests (Memoization) + // -------------------------------------------------------------------------- + describe('Callback Stability', () => { + it('should maintain stable showFilePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.showFilePreview + + // Act + rerender() + + // Assert + expect(result.current.showFilePreview).toBe(initialCallback) + }) + + it('should maintain stable hideFilePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.hideFilePreview + + // Act + rerender() + + // Assert + expect(result.current.hideFilePreview).toBe(initialCallback) + }) + + it('should maintain stable showNotionPagePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.showNotionPagePreview + + // Act + rerender() + + // Assert + expect(result.current.showNotionPagePreview).toBe(initialCallback) + }) + + it('should maintain stable hideNotionPagePreview 
      // Arrange
      const { result, rerender } = renderHook(() => usePreviewState())
      const initialCallback = result.current.hideNotionPagePreview

      // Act
      rerender()

      // Assert
      expect(result.current.hideNotionPagePreview).toBe(initialCallback)
    })

    it('should maintain stable showWebsitePreview callback reference', () => {
      // Arrange
      const { result, rerender } = renderHook(() => usePreviewState())
      const initialCallback = result.current.showWebsitePreview

      // Act
      rerender()

      // Assert
      expect(result.current.showWebsitePreview).toBe(initialCallback)
    })

    it('should maintain stable hideWebsitePreview callback reference', () => {
      // Arrange
      const { result, rerender } = renderHook(() => usePreviewState())
      const initialCallback = result.current.hideWebsitePreview

      // Act
      rerender()

      // Assert
      expect(result.current.hideWebsitePreview).toBe(initialCallback)
    })
  })
})
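// A minimal sketch of a hook that would satisfy the suites above (the real
// implementation lives in './hooks' and may differ): reference-stable
// callbacks fall out of useCallback with empty dependency arrays, which is
// exactly what the Callback Stability suite verifies across rerenders.
//
// import { useCallback, useState } from 'react'
//
// const usePreviewStateSketch = () => {
//   const [currentFile, setCurrentFile] = useState<File | undefined>()
//   // Empty deps: setters are stable, so each callback keeps its identity.
//   const showFilePreview = useCallback((file: File) => setCurrentFile(file), [])
//   const hideFilePreview = useCallback(() => setCurrentFile(undefined), [])
//   return { currentFile, showFilePreview, hideFilePreview }
// }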
// ==========================================
// DataSourceTypeSelector Component Tests
// ==========================================
describe('DataSourceTypeSelector', () => {
  const defaultSelectorProps = {
    currentType: DataSourceType.FILE,
    disabled: false,
    onChange: vi.fn(),
    onClearPreviews: vi.fn(),
  }

  beforeEach(() => {
    vi.clearAllMocks()
  })

  // --------------------------------------------------------------------------
  // Rendering Tests
  // --------------------------------------------------------------------------
  describe('Rendering', () => {
    it('should render all data source options when web is enabled', () => {
      // Arrange & Act
      render(<DataSourceTypeSelector {...defaultSelectorProps} />)

      // Assert
      expect(screen.getByText('datasetCreation.stepOne.dataSourceType.file')).toBeInTheDocument()
      expect(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')).toBeInTheDocument()
      expect(screen.getByText('datasetCreation.stepOne.dataSourceType.web')).toBeInTheDocument()
    })

    it('should highlight active type', () => {
      // Arrange & Act
      const { container } = render(
        <DataSourceTypeSelector {...defaultSelectorProps} />,
      )

      // Assert - The active item should have the active class
      const items = container.querySelectorAll('[class*="dataSourceItem"]')
      expect(items.length).toBeGreaterThan(0)
    })
  })

  // --------------------------------------------------------------------------
  // User Interactions Tests
  // --------------------------------------------------------------------------
  describe('User Interactions', () => {
    it('should call onChange when a type is clicked', () => {
      // Arrange
      const onChange = vi.fn()
      render(<DataSourceTypeSelector {...defaultSelectorProps} onChange={onChange} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion'))

      // Assert
      expect(onChange).toHaveBeenCalledWith(DataSourceType.NOTION)
    })

    it('should call onClearPreviews when a type is clicked', () => {
      // Arrange
      const onClearPreviews = vi.fn()
      render(<DataSourceTypeSelector {...defaultSelectorProps} onClearPreviews={onClearPreviews} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.web'))

      // Assert
      expect(onClearPreviews).toHaveBeenCalledWith(DataSourceType.WEB)
    })

    it('should not call onChange when disabled', () => {
      // Arrange
      const onChange = vi.fn()
      render(<DataSourceTypeSelector {...defaultSelectorProps} disabled onChange={onChange} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion'))

      // Assert
      expect(onChange).not.toHaveBeenCalled()
    })

    it('should not call onClearPreviews when disabled', () => {
      // Arrange
      const onClearPreviews = vi.fn()
      render(<DataSourceTypeSelector {...defaultSelectorProps} disabled onClearPreviews={onClearPreviews} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion'))

      // Assert
      expect(onClearPreviews).not.toHaveBeenCalled()
    })
  })
})
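// The onClearPreviews assertions above assume the parent passes a handler
// that receives the newly selected type and hides whichever previews no
// longer apply -- mirroring the handleClearPreviews callback added to
// step-one/index.tsx later in this diff:
//
// const handleClearPreviews = (newType: DataSourceType) => {
//   if (newType !== DataSourceType.FILE)
//     hideFilePreview()
//   if (newType !== DataSourceType.NOTION)
//     hideNotionPagePreview()
//   if (newType !== DataSourceType.WEB)
//     hideWebsitePreview()
// }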
// ==========================================
// NextStepButton Component Tests
// ==========================================
describe('NextStepButton', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // --------------------------------------------------------------------------
  // Rendering Tests
  // --------------------------------------------------------------------------
  describe('Rendering', () => {
    it('should render with correct label', () => {
      // Arrange & Act
      render(<NextStepButton disabled={false} onClick={vi.fn()} />)

      // Assert
      expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument()
    })

    it('should render with arrow icon', () => {
      // Arrange & Act
      const { container } = render(<NextStepButton disabled={false} onClick={vi.fn()} />)

      // Assert
      const svgIcon = container.querySelector('svg')
      expect(svgIcon).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Props Tests
  // --------------------------------------------------------------------------
  describe('Props', () => {
    it('should be disabled when disabled prop is true', () => {
      // Arrange & Act
      render(<NextStepButton disabled onClick={vi.fn()} />)

      // Assert
      expect(screen.getByRole('button')).toBeDisabled()
    })

    it('should be enabled when disabled prop is false', () => {
      // Arrange & Act
      render(<NextStepButton disabled={false} onClick={vi.fn()} />)

      // Assert
      expect(screen.getByRole('button')).not.toBeDisabled()
    })

    it('should call onClick when clicked and not disabled', () => {
      // Arrange
      const onClick = vi.fn()
      render(<NextStepButton disabled={false} onClick={onClick} />)

      // Act
      fireEvent.click(screen.getByRole('button'))

      // Assert
      expect(onClick).toHaveBeenCalledTimes(1)
    })

    it('should not call onClick when clicked and disabled', () => {
      // Arrange
      const onClick = vi.fn()
      render(<NextStepButton disabled onClick={onClick} />)

      // Act
      fireEvent.click(screen.getByRole('button'))

      // Assert
      expect(onClick).not.toHaveBeenCalled()
    })
  })
})

// ==========================================
// PreviewPanel Component Tests
// ==========================================
describe('PreviewPanel', () => {
  const defaultPreviewProps = {
    currentFile: undefined as File | undefined,
    currentNotionPage: undefined as NotionPage | undefined,
    currentWebsite: undefined as CrawlResultItem | undefined,
    notionCredentialId: 'cred-1',
    isShowPlanUpgradeModal: false,
    hideFilePreview: vi.fn(),
    hideNotionPagePreview: vi.fn(),
    hideWebsitePreview: vi.fn(),
    hidePlanUpgradeModal: vi.fn(),
  }

  beforeEach(() => {
    vi.clearAllMocks()
  })

  // --------------------------------------------------------------------------
  // Conditional Rendering Tests
  // --------------------------------------------------------------------------
  describe('Conditional Rendering', () => {
    it('should not render FilePreview when currentFile is undefined', () => {
      // Arrange & Act
      render(<PreviewPanel {...defaultPreviewProps} />)

      // Assert
      expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument()
    })

    it('should render FilePreview when currentFile is defined', () => {
      // Arrange
      const file = new File(['test'], 'test.txt')

      // Act
      render(<PreviewPanel {...defaultPreviewProps} currentFile={file} />)

      // Assert
      expect(screen.getByTestId('file-preview')).toBeInTheDocument()
    })

    it('should not render NotionPagePreview when currentNotionPage is undefined', () => {
      // Arrange & Act
      render(<PreviewPanel {...defaultPreviewProps} />)

      // Assert
      expect(screen.queryByTestId('notion-page-preview')).not.toBeInTheDocument()
    })

    it('should render NotionPagePreview when currentNotionPage is defined', () => {
      // Arrange
      const page = createMockNotionPage()

      // Act
      render(<PreviewPanel {...defaultPreviewProps} currentNotionPage={page} />)

      // Assert
      expect(screen.getByTestId('notion-page-preview')).toBeInTheDocument()
    })

    it('should not render WebsitePreview when currentWebsite is undefined', () => {
      // Arrange & Act
      render(<PreviewPanel {...defaultPreviewProps} />)

      // Assert - pagePreview is the title shown in WebsitePreview
      expect(screen.queryByText('datasetCreation.stepOne.pagePreview')).not.toBeInTheDocument()
    })

    it('should render WebsitePreview when currentWebsite is defined', () => {
      // Arrange
      const website = createMockCrawlResult()

      // Act
      render(<PreviewPanel {...defaultPreviewProps} currentWebsite={website} />)

      // Assert - Check for the preview title and source URL
      expect(screen.getByText('datasetCreation.stepOne.pagePreview')).toBeInTheDocument()
      expect(screen.getByText(website.source_url)).toBeInTheDocument()
    })

    it('should not render PlanUpgradeModal when isShowPlanUpgradeModal is false', () => {
      // Arrange & Act
      render(<PreviewPanel {...defaultPreviewProps} />)

      // Assert
      expect(screen.queryByTestId('plan-upgrade-modal')).not.toBeInTheDocument()
    })

    it('should render PlanUpgradeModal when isShowPlanUpgradeModal is true', () => {
      // Arrange & Act
      render(<PreviewPanel {...defaultPreviewProps} isShowPlanUpgradeModal />)

      // Assert
      expect(screen.getByTestId('plan-upgrade-modal')).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Event Handler Tests
  // --------------------------------------------------------------------------
  describe('Event Handlers', () => {
    it('should call hideFilePreview when file preview close is clicked', () => {
      // Arrange
      const hideFilePreview = vi.fn()
      const file = new File(['test'], 'test.txt')
      render(<PreviewPanel {...defaultPreviewProps} currentFile={file} hideFilePreview={hideFilePreview} />)

      // Act
      fireEvent.click(screen.getByTestId('hide-file-preview'))

      // Assert
      expect(hideFilePreview).toHaveBeenCalledTimes(1)
    })

    it('should call hideNotionPagePreview when notion preview close is clicked', () => {
      // Arrange
      const hideNotionPagePreview = vi.fn()
      const page = createMockNotionPage()
      render(<PreviewPanel {...defaultPreviewProps} currentNotionPage={page} hideNotionPagePreview={hideNotionPagePreview} />)

      // Act
      fireEvent.click(screen.getByTestId('hide-notion-preview'))

      // Assert
      expect(hideNotionPagePreview).toHaveBeenCalledTimes(1)
    })

    it('should call hideWebsitePreview when website preview close is clicked', () => {
      // Arrange
      const hideWebsitePreview = vi.fn()
      const website = createMockCrawlResult()
      const { container } = render(<PreviewPanel {...defaultPreviewProps} currentWebsite={website} hideWebsitePreview={hideWebsitePreview} />)

      // Act - Find the close button (div with cursor-pointer class containing the XMarkIcon)
      const closeButton = container.querySelector('.cursor-pointer')
      expect(closeButton).toBeInTheDocument()
      fireEvent.click(closeButton!)

      // Assert
      expect(hideWebsitePreview).toHaveBeenCalledTimes(1)
    })

    it('should call hidePlanUpgradeModal when modal close is clicked', () => {
      // Arrange
      const hidePlanUpgradeModal = vi.fn()
      render(<PreviewPanel {...defaultPreviewProps} isShowPlanUpgradeModal hidePlanUpgradeModal={hidePlanUpgradeModal} />)

      // Act
      fireEvent.click(screen.getByTestId('close-upgrade-modal'))

      // Assert
      expect(hidePlanUpgradeModal).toHaveBeenCalledTimes(1)
    })
  })
})
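// The conditional-rendering suite above matches a panel that mounts each
// preview only while its piece of state is set (assumed shape; the prop
// names below are illustrative, the actual component lives in './components'):
//
// const PreviewPanelSketch = (props: typeof defaultPreviewProps) => (
//   <>
//     {props.currentFile && <FilePreview file={props.currentFile} hidePreview={props.hideFilePreview} />}
//     {props.currentNotionPage && <NotionPagePreview currentPage={props.currentNotionPage} hidePreview={props.hideNotionPagePreview} />}
//     {props.currentWebsite && <WebsitePreview payload={props.currentWebsite} hidePreview={props.hideWebsitePreview} />}
//     {props.isShowPlanUpgradeModal && <PlanUpgradeModal onClose={props.hidePlanUpgradeModal} />}
//   </>
// )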
// ==========================================
// StepOne Component Tests
// ==========================================
describe('StepOne', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    mockDatasetDetail = undefined
    mockPlan = {
      type: Plan.professional,
      usage: { vectorSpace: 50, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 },
      total: { vectorSpace: 100, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 },
    }
    mockEnableBilling = false
  })

  // --------------------------------------------------------------------------
  // Rendering Tests
  // --------------------------------------------------------------------------
  describe('Rendering', () => {
    it('should render without crashing', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} />)

      // Assert
      expect(screen.getByText('datasetCreation.steps.one')).toBeInTheDocument()
    })

    it('should render DataSourceTypeSelector when not editing existing dataset', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} />)

      // Assert
      expect(screen.getByText('datasetCreation.stepOne.dataSourceType.file')).toBeInTheDocument()
    })

    it('should render FileUploader when dataSourceType is FILE', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.FILE} />)

      // Assert
      expect(screen.getByTestId('file-uploader')).toBeInTheDocument()
    })

    it('should render NotionConnector when dataSourceType is NOTION and not authenticated', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} />)

      // Assert - NotionConnector shows sync title and connect button
      expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.connect/i })).toBeInTheDocument()
    })

    it('should render NotionPageSelector when dataSourceType is NOTION and authenticated', () => {
      // Arrange
      const authedDataSourceList = [createMockDataSourceAuth()]

      // Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={authedDataSourceList} />)

      // Assert
      expect(screen.getByTestId('notion-page-selector')).toBeInTheDocument()
    })

    it('should render Website when dataSourceType is WEB', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.WEB} />)

      // Assert
      expect(screen.getByTestId('website')).toBeInTheDocument()
    })

    it('should render empty dataset creation link when no datasetId', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} />)

      // Assert
      expect(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation')).toBeInTheDocument()
    })

    it('should not render empty dataset creation link when datasetId exists', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} datasetId="dataset-id" />)

      // Assert
      expect(screen.queryByText('datasetCreation.stepOne.emptyDatasetCreation')).not.toBeInTheDocument()
    })
  })
  // --------------------------------------------------------------------------
  // Props Tests
  // --------------------------------------------------------------------------
  describe('Props', () => {
    it('should pass files to FileUploader', () => {
      // Arrange
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} />)

      // Assert
      expect(screen.getByTestId('file-count')).toHaveTextContent('1')
    })

    it('should call onSetting when NotionConnector connect button is clicked', () => {
      // Arrange
      const onSetting = vi.fn()
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} onSetting={onSetting} />)

      // Act - The NotionConnector's button calls onSetting
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.connect/i }))

      // Assert
      expect(onSetting).toHaveBeenCalledTimes(1)
    })

    it('should call changeType when data source type is changed', () => {
      // Arrange
      const changeType = vi.fn()
      render(<StepOne {...defaultProps} changeType={changeType} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion'))

      // Assert
      expect(changeType).toHaveBeenCalledWith(DataSourceType.NOTION)
    })
  })

  // --------------------------------------------------------------------------
  // State Management Tests
  // --------------------------------------------------------------------------
  describe('State Management', () => {
    it('should open empty dataset modal when link is clicked', () => {
      // Arrange
      render(<StepOne {...defaultProps} />)

      // Act
      fireEvent.click(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation'))

      // Assert
      expect(screen.getByTestId('empty-dataset-modal')).toBeInTheDocument()
    })

    it('should close empty dataset modal when close is clicked', () => {
      // Arrange
      render(<StepOne {...defaultProps} />)
      fireEvent.click(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation'))

      // Act
      fireEvent.click(screen.getByTestId('close-modal'))

      // Assert
      expect(screen.queryByTestId('empty-dataset-modal')).not.toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Memoization Tests
  // --------------------------------------------------------------------------
  describe('Memoization', () => {
    it('should correctly compute isNotionAuthed based on authedDataSourceList', () => {
      // Arrange - No auth
      const { rerender } = render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} />)
      // NotionConnector shows the sync title when not authenticated
      expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument()

      // Act - Add auth
      const authedDataSourceList = [createMockDataSourceAuth()]
      rerender(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={authedDataSourceList} />)

      // Assert
      expect(screen.getByTestId('notion-page-selector')).toBeInTheDocument()
    })

    it('should correctly compute fileNextDisabled when files are empty', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} />)

      // Assert - Button should be disabled
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled()
    })

    it('should correctly compute fileNextDisabled when files are loaded', () => {
      // Arrange
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} />)

      // Assert - Button should be enabled
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).not.toBeDisabled()
    })

    it('should correctly compute fileNextDisabled when some files are not uploaded', () => {
      // Arrange - Create a file item without id (not yet uploaded)
      const file = new File(['test'], 'test.txt', { type: 'text/plain' })
      const fileItem: FileItem = {
        fileID: 'temp-id',
        file: Object.assign(file, { id: undefined, extension: 'txt', mime_type: 'text/plain' }),
        progress: 0,
      }

      // Act
      render(<StepOne {...defaultProps} files={[fileItem]} />)

      // Assert - Button should be disabled
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled()
    })
  })
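  // The three fileNextDisabled cases above map onto the memo added to
  // step-one/index.tsx later in this diff (the middle early-return is an
  // assumed completion of the hunk shown there):
  //
  // const fileNextDisabled = useMemo(() => {
  //   if (!files.length)
  //     return true // nothing selected yet
  //   if (files.some(file => !file.file.id))
  //     return true // an upload has not finished (no server-assigned id)
  //   return isShowVectorSpaceFull // otherwise gate on the vector-space quota
  // }, [files, isShowVectorSpaceFull])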
  // --------------------------------------------------------------------------
  // Callback Tests
  // --------------------------------------------------------------------------
  describe('Callbacks', () => {
    it('should call onStepChange when next button is clicked with valid files', () => {
      // Arrange
      const onStepChange = vi.fn()
      const files = [createMockFileItem()]
      render(<StepOne {...defaultProps} files={files} onStepChange={onStepChange} />)

      // Act
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }))

      // Assert
      expect(onStepChange).toHaveBeenCalledTimes(1)
    })

    it('should show plan upgrade modal when batch upload is not supported and multiple files are selected', () => {
      // Arrange
      mockEnableBilling = true
      mockPlan.type = Plan.sandbox
      const files = [createMockFileItem(), createMockFileItem()]
      render(<StepOne {...defaultProps} files={files} />)

      // Act
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }))

      // Assert
      expect(screen.getByTestId('plan-upgrade-modal')).toBeInTheDocument()
    })

    it('should show upgrade card when in sandbox plan with files', () => {
      // Arrange
      mockEnableBilling = true
      mockPlan.type = Plan.sandbox
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} />)

      // Assert
      expect(screen.getByTestId('upgrade-card')).toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Vector Space Full Tests
  // --------------------------------------------------------------------------
  describe('Vector Space Full', () => {
    it('should show VectorSpaceFull when vector space is full and billing is enabled', () => {
      // Arrange
      mockEnableBilling = true
      mockPlan.usage.vectorSpace = 100
      mockPlan.total.vectorSpace = 100
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} />)

      // Assert
      expect(screen.getByTestId('vector-space-full')).toBeInTheDocument()
    })

    it('should disable next button when vector space is full', () => {
      // Arrange
      mockEnableBilling = true
      mockPlan.usage.vectorSpace = 100
      mockPlan.total.vectorSpace = 100
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} />)

      // Assert
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled()
    })
  })

  // --------------------------------------------------------------------------
  // Preview Integration Tests
  // --------------------------------------------------------------------------
  describe('Preview Integration', () => {
    it('should show file preview when file preview button is clicked', () => {
      // Arrange
      render(<StepOne {...defaultProps} />)

      // Act
      fireEvent.click(screen.getByTestId('preview-file'))

      // Assert
      expect(screen.getByTestId('file-preview')).toBeInTheDocument()
    })

    it('should hide file preview when hide button is clicked', () => {
      // Arrange
      render(<StepOne {...defaultProps} />)
      fireEvent.click(screen.getByTestId('preview-file'))

      // Act
      fireEvent.click(screen.getByTestId('hide-file-preview'))

      // Assert
      expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument()
    })

    it('should show notion page preview when preview button is clicked', () => {
      // Arrange
      const authedDataSourceList = [createMockDataSourceAuth()]
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={authedDataSourceList} />)

      // Act
      fireEvent.click(screen.getByTestId('preview-notion'))

      // Assert
      expect(screen.getByTestId('notion-page-preview')).toBeInTheDocument()
    })

    it('should show website preview when preview button is clicked', () => {
      // Arrange
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.WEB} />)

      // Act
      fireEvent.click(screen.getByTestId('preview-website'))

      // Assert - Check for pagePreview title which is shown by WebsitePreview
      expect(screen.getByText('datasetCreation.stepOne.pagePreview')).toBeInTheDocument()
    })
  })
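  // The plan-upgrade expectations above exercise the batch-upload gate added
  // to step-one/index.tsx below: on sandbox plans a per-type lookup flags
  // multi-item selections and opens the upgrade modal instead of advancing.
  // A plausible shape for that lookup's signature (the type annotation here
  // is reconstructed, not copied from the source):
  //
  // const MULTIPLE_ITEMS_CHECK: Record<DataSourceType, (params: {
  //   files: FileItem[]
  //   notionPages: NotionPage[]
  //   websitePages: CrawlResultItem[]
  // }) => boolean> = {
  //   [DataSourceType.FILE]: ({ files }) => files.length > 1,
  //   [DataSourceType.NOTION]: ({ notionPages }) => notionPages.length > 1,
  //   [DataSourceType.WEB]: ({ websitePages }) => websitePages.length > 1,
  // }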
  // --------------------------------------------------------------------------
  // Edge Cases
  // --------------------------------------------------------------------------
  describe('Edge Cases', () => {
    it('should handle empty notionPages array', () => {
      // Arrange
      const authedDataSourceList = [createMockDataSourceAuth()]

      // Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={authedDataSourceList} notionPages={[]} />)

      // Assert - Button should be disabled when no pages selected
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled()
    })

    it('should handle empty websitePages array', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.WEB} websitePages={[]} />)

      // Assert - Button should be disabled when no pages crawled
      expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled()
    })

    it('should handle empty authedDataSourceList', () => {
      // Arrange & Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={[]} />)

      // Assert - Should show NotionConnector with connect button
      expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument()
    })

    it('should handle authedDataSourceList without notion credentials', () => {
      // Arrange
      const authedDataSourceList = [createMockDataSourceAuth({ credentials_list: [] })]

      // Act
      render(<StepOne {...defaultProps} dataSourceType={DataSourceType.NOTION} authedDataSourceList={authedDataSourceList} />)

      // Assert - Should show NotionConnector with connect button
      expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument()
    })

    it('should clear previews when switching data source types', () => {
      // Arrange
      render(<StepOne {...defaultProps} />)
      fireEvent.click(screen.getByTestId('preview-file'))
      expect(screen.getByTestId('file-preview')).toBeInTheDocument()

      // Act - Change to NOTION
      fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion'))

      // Assert - File preview should be cleared
      expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument()
    })
  })

  // --------------------------------------------------------------------------
  // Integration Tests
  // --------------------------------------------------------------------------
  describe('Integration', () => {
    it('should complete file upload flow', () => {
      // Arrange
      const onStepChange = vi.fn()
      const files = [createMockFileItem()]

      // Act
      render(<StepOne {...defaultProps} files={files} onStepChange={onStepChange} />)
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }))

      // Assert
      expect(onStepChange).toHaveBeenCalled()
    })

    it('should complete notion page selection flow', () => {
      // Arrange
      const onStepChange = vi.fn()
      const authedDataSourceList = [createMockDataSourceAuth()]
      const notionPages = [createMockNotionPage()]

      // Act
      render(
        <StepOne
          {...defaultProps}
          dataSourceType={DataSourceType.NOTION}
          authedDataSourceList={authedDataSourceList}
          notionPages={notionPages}
          onStepChange={onStepChange}
        />,
      )
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }))

      // Assert
      expect(onStepChange).toHaveBeenCalled()
    })

    it('should complete website crawl flow', () => {
      // Arrange
      const onStepChange = vi.fn()
      const websitePages = [createMockCrawlResult()]

      // Act
      render(
        <StepOne
          {...defaultProps}
          dataSourceType={DataSourceType.WEB}
          websitePages={websitePages}
          onStepChange={onStepChange}
        />,
      )
      fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }))

      // Assert
      expect(onStepChange).toHaveBeenCalled()
    })
  })
})
diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx
index 5c74e69e7e..a86c9d86c2 100644
--- a/web/app/components/datasets/create/step-one/index.tsx
+++ b/web/app/components/datasets/create/step-one/index.tsx
@@ -1,29 +1,25 @@
 'use client'
+
 import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types'
 import type { DataSourceProvider, NotionPage } from '@/models/common'
 import type { CrawlOptions, CrawlResultItem, FileItem } from '@/models/datasets'
-import { RiArrowRightLine, RiFolder6Line } from '@remixicon/react'
+import { RiFolder6Line } from 
'@remixicon/react' import { useBoolean } from 'ahooks' -import * as React from 'react' -import { useCallback, useMemo, useState } from 'react' +import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' import NotionConnector from '@/app/components/base/notion-connector' import { NotionPageSelector } from '@/app/components/base/notion-page-selector' -import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' import { Plan } from '@/app/components/billing/type' import VectorSpaceFull from '@/app/components/billing/vector-space-full' -import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import { DataSourceType } from '@/models/datasets' import { cn } from '@/utils/classnames' import EmptyDatasetCreationModal from '../empty-dataset-creation-modal' -import FilePreview from '../file-preview' import FileUploader from '../file-uploader' -import NotionPagePreview from '../notion-page-preview' import Website from '../website' -import WebsitePreview from '../website/preview' +import { DataSourceTypeSelector, NextStepButton, PreviewPanel } from './components' +import { usePreviewState } from './hooks' import s from './index.module.css' import UpgradeCard from './upgrade-card' @@ -50,6 +46,24 @@ type IStepOneProps = { authedDataSourceList: DataSourceAuth[] } +// Helper function to check if notion is authenticated +function checkNotionAuth(authedDataSourceList: DataSourceAuth[]): boolean { + const notionSource = authedDataSourceList.find(item => item.provider === 'notion_datasource') + return Boolean(notionSource && notionSource.credentials_list.length > 0) +} + +// Helper function to get notion credential list +function getNotionCredentialList(authedDataSourceList: DataSourceAuth[]) { + return authedDataSourceList.find(item => item.provider === 'notion_datasource')?.credentials_list || [] +} + +// Lookup table for checking multiple items by data source type +const MULTIPLE_ITEMS_CHECK: Record boolean> = { + [DataSourceType.FILE]: ({ files }) => files.length > 1, + [DataSourceType.NOTION]: ({ notionPages }) => notionPages.length > 1, + [DataSourceType.WEB]: ({ websitePages }) => websitePages.length > 1, +} + const StepOne = ({ datasetId, dataSourceType: inCreatePageDataSourceType, @@ -72,76 +86,47 @@ const StepOne = ({ onCrawlOptionsChange, authedDataSourceList, }: IStepOneProps) => { - const dataset = useDatasetDetailContextWithSelector(state => state.dataset) - const [showModal, setShowModal] = useState(false) - const [currentFile, setCurrentFile] = useState() - const [currentNotionPage, setCurrentNotionPage] = useState() - const [currentWebsite, setCurrentWebsite] = useState() const { t } = useTranslation() + const dataset = useDatasetDetailContextWithSelector(state => state.dataset) + const { plan, enableBilling } = useProviderContext() - const modalShowHandle = () => setShowModal(true) - const modalCloseHandle = () => setShowModal(false) + // Preview state management + const { + currentFile, + currentNotionPage, + currentWebsite, + showFilePreview, + hideFilePreview, + showNotionPagePreview, + hideNotionPagePreview, + showWebsitePreview, + hideWebsitePreview, + } = usePreviewState() - const updateCurrentFile = useCallback((file: File) => { - setCurrentFile(file) - }, []) + // Empty dataset modal state + 
const [showModal, { setTrue: openModal, setFalse: closeModal }] = useBoolean(false) - const hideFilePreview = useCallback(() => { - setCurrentFile(undefined) - }, []) - - const updateCurrentPage = useCallback((page: NotionPage) => { - setCurrentNotionPage(page) - }, []) - - const hideNotionPagePreview = useCallback(() => { - setCurrentNotionPage(undefined) - }, []) - - const updateWebsite = useCallback((website: CrawlResultItem) => { - setCurrentWebsite(website) - }, []) - - const hideWebsitePreview = useCallback(() => { - setCurrentWebsite(undefined) - }, []) + // Plan upgrade modal state + const [isShowPlanUpgradeModal, { setTrue: showPlanUpgradeModal, setFalse: hidePlanUpgradeModal }] = useBoolean(false) + // Computed values const shouldShowDataSourceTypeList = !datasetId || (datasetId && !dataset?.data_source_type) const isInCreatePage = shouldShowDataSourceTypeList - const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : dataset?.data_source_type - const { plan, enableBilling } = useProviderContext() - const allFileLoaded = (files.length > 0 && files.every(file => file.file.id)) - const hasNotin = notionPages.length > 0 + // Default to FILE type when no type is provided from either source + const dataSourceType = isInCreatePage + ? (inCreatePageDataSourceType ?? DataSourceType.FILE) + : (dataset?.data_source_type ?? DataSourceType.FILE) + + const allFileLoaded = files.length > 0 && files.every(file => file.file.id) + const hasNotion = notionPages.length > 0 const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace - const isShowVectorSpaceFull = (allFileLoaded || hasNotin) && isVectorSpaceFull && enableBilling + const isShowVectorSpaceFull = (allFileLoaded || hasNotion) && isVectorSpaceFull && enableBilling const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox - const notSupportBatchUpload = !supportBatchUpload - const [isShowPlanUpgradeModal, { - setTrue: showPlanUpgradeModal, - setFalse: hidePlanUpgradeModal, - }] = useBoolean(false) - const onStepChange = useCallback(() => { - if (notSupportBatchUpload) { - let isMultiple = false - if (dataSourceType === DataSourceType.FILE && files.length > 1) - isMultiple = true + const isNotionAuthed = useMemo(() => checkNotionAuth(authedDataSourceList), [authedDataSourceList]) + const notionCredentialList = useMemo(() => getNotionCredentialList(authedDataSourceList), [authedDataSourceList]) - if (dataSourceType === DataSourceType.NOTION && notionPages.length > 1) - isMultiple = true - - if (dataSourceType === DataSourceType.WEB && websitePages.length > 1) - isMultiple = true - - if (isMultiple) { - showPlanUpgradeModal() - return - } - } - doOnStepChange() - }, [dataSourceType, doOnStepChange, files.length, notSupportBatchUpload, notionPages.length, showPlanUpgradeModal, websitePages.length]) - - const nextDisabled = useMemo(() => { + const fileNextDisabled = useMemo(() => { if (!files.length) return true if (files.some(file => !file.file.id)) @@ -149,109 +134,50 @@ const StepOne = ({ return isShowVectorSpaceFull }, [files, isShowVectorSpaceFull]) - const isNotionAuthed = useMemo(() => { - if (!authedDataSourceList) - return false - const notionSource = authedDataSourceList.find(item => item.provider === 'notion_datasource') - if (!notionSource) - return false - return notionSource.credentials_list.length > 0 - }, [authedDataSourceList]) + // Clear previews when switching data source type + const handleClearPreviews = useCallback((newType: DataSourceType) => { + if (newType !== 
DataSourceType.FILE) + hideFilePreview() + if (newType !== DataSourceType.NOTION) + hideNotionPagePreview() + if (newType !== DataSourceType.WEB) + hideWebsitePreview() + }, [hideFilePreview, hideNotionPagePreview, hideWebsitePreview]) - const notionCredentialList = useMemo(() => { - return authedDataSourceList.find(item => item.provider === 'notion_datasource')?.credentials_list || [] - }, [authedDataSourceList]) + // Handle step change with batch upload check + const onStepChange = useCallback(() => { + if (!supportBatchUpload && dataSourceType) { + const checkFn = MULTIPLE_ITEMS_CHECK[dataSourceType] + if (checkFn?.({ files, notionPages, websitePages })) { + showPlanUpgradeModal() + return + } + } + doOnStepChange() + }, [dataSourceType, doOnStepChange, files, supportBatchUpload, notionPages, showPlanUpgradeModal, websitePages]) return (
+ {/* Left Panel - Form */}
- { - shouldShowDataSourceTypeList && ( + {shouldShowDataSourceTypeList && ( + <>
{t('steps.one', { ns: 'datasetCreation' })}
- ) - } - { - shouldShowDataSourceTypeList && ( -
-
-            {
-              if (dataSourceTypeDisable)
-                return
-              changeType(DataSourceType.FILE)
-              hideNotionPagePreview()
-              hideWebsitePreview()
-            }}
-          >
-
-
-            {t('stepOne.dataSourceType.file', { ns: 'datasetCreation' })}
-
-
-
-            {
-              if (dataSourceTypeDisable)
-                return
-              changeType(DataSourceType.NOTION)
-              hideFilePreview()
-              hideWebsitePreview()
-            }}
-          >
-
-
-            {t('stepOne.dataSourceType.notion', { ns: 'datasetCreation' })}
-
-
- {(ENABLE_WEBSITE_FIRECRAWL || ENABLE_WEBSITE_JINAREADER || ENABLE_WEBSITE_WATERCRAWL) && ( -
-            {
-              if (dataSourceTypeDisable)
-                return
-              changeType(DataSourceType.WEB)
-              hideFilePreview()
-              hideNotionPagePreview()
-            }}
-          >
-
-
-            {t('stepOne.dataSourceType.web', { ns: 'datasetCreation' })}
-
-
- )} -
- ) - } + + + )} + + {/* File Data Source */} {dataSourceType === DataSourceType.FILE && ( <> {isShowVectorSpaceFull && ( @@ -268,24 +194,17 @@ const StepOne = ({
)} -
- -
- { - enableBilling && plan.type === Plan.sandbox && files.length > 0 && ( -
-
- -
- ) - } + + {enableBilling && plan.type === Plan.sandbox && files.length > 0 && ( +
+
+ +
+ )} )} + + {/* Notion Data Source */} {dataSourceType === DataSourceType.NOTION && ( <> {!isNotionAuthed && } @@ -295,7 +214,7 @@ const StepOne = ({ page.page_id)} onSelect={updateNotionPages} - onPreview={updateCurrentPage} + onPreview={showNotionPagePreview} credentialList={notionCredentialList} onSelectCredential={updateNotionCredentialId} datasetId={datasetId} @@ -306,23 +225,21 @@ const StepOne = ({
)} -
- -
+ )} )} + + {/* Web Data Source */} {dataSourceType === DataSourceType.WEB && ( <>
)} -
- -
+ )} + + {/* Empty Dataset Creation Link */} {!datasetId && ( <>
- + {t('stepOne.emptyDatasetCreation', { ns: 'datasetCreation' })} )}
- +
-
- {currentFile && } - {currentNotionPage && ( - - )} - {currentWebsite && } - {isShowPlanUpgradeModal && ( - - )} -
+ + {/* Right Panel - Preview */} +
) diff --git a/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx b/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx new file mode 100644 index 0000000000..03ca543ee7 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx @@ -0,0 +1,36 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import CornerLabel from '@/app/components/base/corner-label' + +type CornerLabelsProps = { + dataset: DataSet +} + +const CornerLabels = ({ dataset }: CornerLabelsProps) => { + const { t } = useTranslation() + + if (!dataset.embedding_available) { + return ( + + ) + } + + if (dataset.runtime_mode === 'rag_pipeline') { + return ( + + ) + } + + return null +} + +export default React.memo(CornerLabels) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx new file mode 100644 index 0000000000..854f34f49c --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx @@ -0,0 +1,62 @@ +import type { DataSet } from '@/models/datasets' +import { RiFileTextFill, RiRobot2Fill } from '@remixicon/react' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import Tooltip from '@/app/components/base/tooltip' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { cn } from '@/utils/classnames' + +const EXTERNAL_PROVIDER = 'external' + +type DatasetCardFooterProps = { + dataset: DataSet +} + +const DatasetCardFooter = ({ dataset }: DatasetCardFooterProps) => { + const { t } = useTranslation() + const { formatTimeFromNow } = useFormatTimeFromNow() + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER + + const documentCount = useMemo(() => { + const availableDocCount = dataset.total_available_documents ?? 0 + if (availableDocCount < dataset.document_count) + return `${availableDocCount} / ${dataset.document_count}` + return `${dataset.document_count}` + }, [dataset.document_count, dataset.total_available_documents]) + + const documentCountTooltip = useMemo(() => { + const availableDocCount = dataset.total_available_documents ?? 0 + if (availableDocCount < dataset.document_count) + return t('partialEnabled', { ns: 'dataset', count: dataset.document_count, num: availableDocCount }) + return t('docAllEnabled', { ns: 'dataset', count: availableDocCount }) + }, [t, dataset.document_count, dataset.total_available_documents]) + + return ( +
+ +
+ + {documentCount} +
+
+ {!isExternalProvider && ( + +
+ + {dataset.app_count} +
+
+ )} + / + {`${t('updated', { ns: 'dataset' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`} +
+ ) +} + +export default React.memo(DatasetCardFooter) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx new file mode 100644 index 0000000000..abe7595e14 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx @@ -0,0 +1,148 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useKnowledge } from '@/hooks/use-knowledge' +import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' +import { cn } from '@/utils/classnames' + +const EXTERNAL_PROVIDER = 'external' + +type DatasetCardHeaderProps = { + dataset: DataSet +} + +// DocModeInfo component - placed before usage +type DocModeInfoProps = { + dataset: DataSet + isExternalProvider: boolean + isShowDocModeInfo: boolean +} + +const DocModeInfo = ({ + dataset, + isExternalProvider, + isShowDocModeInfo, +}: DocModeInfoProps) => { + const { t } = useTranslation() + const { formatIndexingTechniqueAndMethod } = useKnowledge() + + if (isExternalProvider) { + return ( +
+ {t('externalKnowledgeBase', { ns: 'dataset' })} +
+ ) + } + + if (!isShowDocModeInfo) + return null + + const indexingText = dataset.indexing_technique + ? formatIndexingTechniqueAndMethod( + dataset.indexing_technique as 'economy' | 'high_quality', + dataset.retrieval_model_dict?.search_method as Parameters[1], + ) + : '' + + return ( +
+ {dataset.doc_form && ( + + {t(`chunkingMode.${DOC_FORM_TEXT[dataset.doc_form]}`, { ns: 'dataset' })} + + )} + {dataset.indexing_technique && indexingText && ( + + {indexingText} + + )} + {dataset.is_multimodal && ( + + {t('multimodal', { ns: 'dataset' })} + + )} +
+ ) +} + +// Main DatasetCardHeader component +const DatasetCardHeader = ({ dataset }: DatasetCardHeaderProps) => { + const { t } = useTranslation() + const { formatTimeFromNow } = useFormatTimeFromNow() + + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER + + const isShowChunkingModeIcon = dataset.doc_form && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) + const isShowDocModeInfo = Boolean( + dataset.doc_form + && dataset.indexing_technique + && dataset.retrieval_model_dict?.search_method + && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published), + ) + + const chunkingModeIcon = dataset.doc_form ? DOC_FORM_ICON_WITH_BG[dataset.doc_form] : React.Fragment + const Icon = isExternalProvider ? DOC_FORM_ICON_WITH_BG.external : chunkingModeIcon + + const iconInfo = useMemo(() => dataset.icon_info || { + icon: '📙', + icon_type: 'emoji' as const, + icon_background: '#FFF4ED', + icon_url: '', + }, [dataset.icon_info]) + + const editTimeText = useMemo( + () => `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`, + [t, dataset.updated_at, formatTimeFromNow], + ) + + return ( +
+
+ + {(isShowChunkingModeIcon || isExternalProvider) && ( +
+ +
+ )} +
+
+
+ {dataset.name} +
+
+
{dataset.author_name}
+
·
+
{editTimeText}
+
+ +
+
+ ) +} + +export default React.memo(DatasetCardHeader) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx new file mode 100644 index 0000000000..8162bc94c4 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx @@ -0,0 +1,55 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import Confirm from '@/app/components/base/confirm' +import RenameDatasetModal from '../../../rename-modal' + +type ModalState = { + showRenameModal: boolean + showConfirmDelete: boolean + confirmMessage: string +} + +type DatasetCardModalsProps = { + dataset: DataSet + modalState: ModalState + onCloseRename: () => void + onCloseConfirm: () => void + onConfirmDelete: () => void + onSuccess?: () => void +} + +const DatasetCardModals = ({ + dataset, + modalState, + onCloseRename, + onCloseConfirm, + onConfirmDelete, + onSuccess, +}: DatasetCardModalsProps) => { + const { t } = useTranslation() + + return ( + <> + {modalState.showRenameModal && ( + + )} + {modalState.showConfirmDelete && ( + + )} + + ) +} + +export default React.memo(DatasetCardModals) diff --git a/web/app/components/datasets/list/dataset-card/components/description.tsx b/web/app/components/datasets/list/dataset-card/components/description.tsx new file mode 100644 index 0000000000..79604e92ab --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/description.tsx @@ -0,0 +1,18 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type DescriptionProps = { + dataset: DataSet +} + +const Description = ({ dataset }: DescriptionProps) => ( +
+ {dataset.description} +
+) + +export default React.memo(Description) diff --git a/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx b/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx new file mode 100644 index 0000000000..80ae2fb7a1 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx @@ -0,0 +1,52 @@ +import type { DataSet } from '@/models/datasets' +import { RiMoreFill } from '@remixicon/react' +import * as React from 'react' +import CustomPopover from '@/app/components/base/popover' +import { cn } from '@/utils/classnames' +import Operations from '../operations' + +type OperationsPopoverProps = { + dataset: DataSet + isCurrentWorkspaceDatasetOperator: boolean + openRenameModal: () => void + handleExportPipeline: (include?: boolean) => void + detectIsUsedByApp: () => void +} + +const OperationsPopover = ({ + dataset, + isCurrentWorkspaceDatasetOperator, + openRenameModal, + handleExportPipeline, + detectIsUsedByApp, +}: OperationsPopoverProps) => ( +
+ + )} + className="z-20 min-w-[186px]" + popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]" + position="br" + trigger="click" + btnElement={( +
+ +
+ )} + btnClassName={open => + cn( + 'size-9 cursor-pointer justify-center rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-[2px] ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border', + open ? 'border-components-actionbar-border bg-state-base-hover' : '', + )} + /> +
+) + +export default React.memo(OperationsPopover) diff --git a/web/app/components/datasets/list/dataset-card/components/tag-area.tsx b/web/app/components/datasets/list/dataset-card/components/tag-area.tsx new file mode 100644 index 0000000000..f55a064387 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/tag-area.tsx @@ -0,0 +1,55 @@ +import type { Tag } from '@/app/components/base/tag-management/constant' +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import TagSelector from '@/app/components/base/tag-management/selector' +import { cn } from '@/utils/classnames' + +type TagAreaProps = { + dataset: DataSet + tags: Tag[] + setTags: (tags: Tag[]) => void + onSuccess?: () => void + isHoveringTagSelector: boolean + onClick: (e: React.MouseEvent) => void +} + +const TagArea = React.forwardRef(({ + dataset, + tags, + setTags, + onSuccess, + isHoveringTagSelector, + onClick, +}, ref) => ( +
+
0 && 'visible', + )} + > + tag.id)} + selectedTags={tags} + onCacheUpdate={setTags} + onChange={onSuccess} + /> +
+
+
+)) +TagArea.displayName = 'TagArea' + +export default TagArea diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts new file mode 100644 index 0000000000..ad68a1df1c --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -0,0 +1,138 @@ +import type { Tag } from '@/app/components/base/tag-management/constant' +import type { DataSet } from '@/models/datasets' +import { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' +import { useExportPipelineDSL } from '@/service/use-pipeline' + +type ModalState = { + showRenameModal: boolean + showConfirmDelete: boolean + confirmMessage: string +} + +type UseDatasetCardStateOptions = { + dataset: DataSet + onSuccess?: () => void +} + +export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateOptions) => { + const { t } = useTranslation() + const [tags, setTags] = useState(dataset.tags) + + useEffect(() => { + setTags(dataset.tags) + }, [dataset.tags]) + + // Modal state + const [modalState, setModalState] = useState({ + showRenameModal: false, + showConfirmDelete: false, + confirmMessage: '', + }) + + // Export state + const [exporting, setExporting] = useState(false) + + // Modal handlers + const openRenameModal = useCallback(() => { + setModalState(prev => ({ ...prev, showRenameModal: true })) + }, []) + + const closeRenameModal = useCallback(() => { + setModalState(prev => ({ ...prev, showRenameModal: false })) + }, []) + + const closeConfirmDelete = useCallback(() => { + setModalState(prev => ({ ...prev, showConfirmDelete: false })) + }, []) + + // API mutations + const { mutateAsync: checkUsage } = useCheckDatasetUsage() + const { mutateAsync: deleteDatasetMutation } = useDeleteDataset() + const { mutateAsync: exportPipelineConfig } = useExportPipelineDSL() + + // Export pipeline handler + const handleExportPipeline = useCallback(async (include: boolean = false) => { + const { pipeline_id, name } = dataset + if (!pipeline_id || exporting) + return + + try { + setExporting(true) + const { data } = await exportPipelineConfig({ + pipelineId: pipeline_id, + include, + }) + const a = document.createElement('a') + const file = new Blob([data], { type: 'application/yaml' }) + const url = URL.createObjectURL(file) + a.href = url + a.download = `${name}.pipeline` + a.click() + URL.revokeObjectURL(url) + } + catch { + Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + finally { + setExporting(false) + } + }, [dataset, exportPipelineConfig, exporting, t]) + + // Delete flow handlers + const detectIsUsedByApp = useCallback(async () => { + try { + const { is_using: isUsedByApp } = await checkUsage(dataset.id) + const message = isUsedByApp + ? t('datasetUsedByApp', { ns: 'dataset' })! + : t('deleteDatasetConfirmContent', { ns: 'dataset' })! 
+ setModalState(prev => ({ + ...prev, + confirmMessage: message, + showConfirmDelete: true, + })) + } + catch (e: unknown) { + if (e instanceof Response) { + const res = await e.json() + Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) + } + else { + Toast.notify({ type: 'error', message: (e as Error)?.message || 'Unknown error' }) + } + } + }, [dataset.id, checkUsage, t]) + + const onConfirmDelete = useCallback(async () => { + try { + await deleteDatasetMutation(dataset.id) + Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) + onSuccess?.() + } + finally { + closeConfirmDelete() + } + }, [dataset.id, deleteDatasetMutation, onSuccess, t, closeConfirmDelete]) + + return { + // Tag state + tags, + setTags, + + // Modal state + modalState, + openRenameModal, + closeRenameModal, + closeConfirmDelete, + + // Export state + exporting, + + // Handlers + handleExportPipeline, + detectIsUsedByApp, + onConfirmDelete, + } +} diff --git a/web/app/components/datasets/list/dataset-card/index.tsx b/web/app/components/datasets/list/dataset-card/index.tsx index 99404b0454..85dba7e8ff 100644 --- a/web/app/components/datasets/list/dataset-card/index.tsx +++ b/web/app/components/datasets/list/dataset-card/index.tsx @@ -1,28 +1,17 @@ 'use client' -import type { Tag } from '@/app/components/base/tag-management/constant' import type { DataSet } from '@/models/datasets' -import { RiFileTextFill, RiMoreFill, RiRobot2Fill } from '@remixicon/react' import { useHover } from 'ahooks' import { useRouter } from 'next/navigation' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' -import AppIcon from '@/app/components/base/app-icon' -import Confirm from '@/app/components/base/confirm' -import CornerLabel from '@/app/components/base/corner-label' -import CustomPopover from '@/app/components/base/popover' -import TagSelector from '@/app/components/base/tag-management/selector' -import Toast from '@/app/components/base/toast' -import Tooltip from '@/app/components/base/tooltip' +import { useMemo, useRef } from 'react' import { useSelector as useAppContextWithSelector } from '@/context/app-context' -import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' -import { useKnowledge } from '@/hooks/use-knowledge' -import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' -import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' -import { useExportPipelineDSL } from '@/service/use-pipeline' -import { cn } from '@/utils/classnames' -import RenameDatasetModal from '../../rename-modal' -import Operations from './operations' +import CornerLabels from './components/corner-labels' +import DatasetCardFooter from './components/dataset-card-footer' +import DatasetCardHeader from './components/dataset-card-header' +import DatasetCardModals from './components/dataset-card-modals' +import Description from './components/description' +import OperationsPopover from './components/operations-popover' +import TagArea from './components/tag-area' +import { useDatasetCardState } from './hooks/use-dataset-card-state' const EXTERNAL_PROVIDER = 'external' @@ -35,320 +24,80 @@ const DatasetCard = ({ dataset, onSuccess, }: DatasetCardProps) => { - const { t } = useTranslation() const { push } = useRouter() const isCurrentWorkspaceDatasetOperator = useAppContextWithSelector(state => state.isCurrentWorkspaceDatasetOperator) - const [tags, setTags] = 
useState(dataset.tags) const tagSelectorRef = useRef(null) const isHoveringTagSelector = useHover(tagSelectorRef) - const [showRenameModal, setShowRenameModal] = useState(false) - const [showConfirmDelete, setShowConfirmDelete] = useState(false) - const [confirmMessage, setConfirmMessage] = useState('') - const [exporting, setExporting] = useState(false) + const { + tags, + setTags, + modalState, + openRenameModal, + closeRenameModal, + closeConfirmDelete, + handleExportPipeline, + detectIsUsedByApp, + onConfirmDelete, + } = useDatasetCardState({ dataset, onSuccess }) - const isExternalProvider = useMemo(() => { - return dataset.provider === EXTERNAL_PROVIDER - }, [dataset.provider]) + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER const isPipelineUnpublished = useMemo(() => { return dataset.runtime_mode === 'rag_pipeline' && !dataset.is_published }, [dataset.runtime_mode, dataset.is_published]) - const isShowChunkingModeIcon = useMemo(() => { - return dataset.doc_form && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) - }, [dataset.doc_form, dataset.runtime_mode, dataset.is_published]) - const isShowDocModeInfo = useMemo(() => { - return dataset.doc_form && dataset.indexing_technique && dataset.retrieval_model_dict?.search_method && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) - }, [dataset.doc_form, dataset.indexing_technique, dataset.retrieval_model_dict?.search_method, dataset.runtime_mode, dataset.is_published]) - const chunkingModeIcon = dataset.doc_form ? DOC_FORM_ICON_WITH_BG[dataset.doc_form] : React.Fragment - const Icon = isExternalProvider ? DOC_FORM_ICON_WITH_BG.external : chunkingModeIcon - const iconInfo = dataset.icon_info || { - icon: '📙', - icon_type: 'emoji', - icon_background: '#FFF4ED', - icon_url: '', + const handleCardClick = (e: React.MouseEvent) => { + e.preventDefault() + if (isExternalProvider) + push(`/datasets/${dataset.id}/hitTesting`) + else if (isPipelineUnpublished) + push(`/datasets/${dataset.id}/pipeline`) + else + push(`/datasets/${dataset.id}/documents`) } - const { formatIndexingTechniqueAndMethod } = useKnowledge() - const documentCount = useMemo(() => { - const availableDocCount = dataset.total_available_documents ?? 0 - if (availableDocCount === dataset.document_count) - return `${dataset.document_count}` - if (availableDocCount < dataset.document_count) - return `${availableDocCount} / ${dataset.document_count}` - }, [dataset.document_count, dataset.total_available_documents]) - const documentCountTooltip = useMemo(() => { - const availableDocCount = dataset.total_available_documents ?? 
0 - if (availableDocCount === dataset.document_count) - return t('docAllEnabled', { ns: 'dataset', count: availableDocCount }) - if (availableDocCount < dataset.document_count) - return t('partialEnabled', { ns: 'dataset', count: dataset.document_count, num: availableDocCount }) - }, [t, dataset.document_count, dataset.total_available_documents]) - const { formatTimeFromNow } = useFormatTimeFromNow() - const editTimeText = useMemo(() => { - return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${formatTimeFromNow(dataset.updated_at * 1000)}` - }, [t, dataset.updated_at, formatTimeFromNow]) - - const openRenameModal = useCallback(() => { - setShowRenameModal(true) - }, []) - - const { mutateAsync: exportPipelineConfig } = useExportPipelineDSL() - - const handleExportPipeline = useCallback(async (include = false) => { - const { pipeline_id, name } = dataset - if (!pipeline_id) - return - - if (exporting) - return - - try { - setExporting(true) - const { data } = await exportPipelineConfig({ - pipelineId: pipeline_id, - include, - }) - const a = document.createElement('a') - const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) - } - catch { - Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) - } - finally { - setExporting(false) - } - }, [dataset, exportPipelineConfig, exporting, t]) - - const detectIsUsedByApp = useCallback(async () => { - try { - const { is_using: isUsedByApp } = await checkIsUsedInApp(dataset.id) - setConfirmMessage(isUsedByApp ? t('datasetUsedByApp', { ns: 'dataset' })! : t('deleteDatasetConfirmContent', { ns: 'dataset' })!) - setShowConfirmDelete(true) - } - catch (e: any) { - const res = await e.json() - Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) - } - }, [dataset.id, t]) - - const onConfirmDelete = useCallback(async () => { - try { - await deleteDataset(dataset.id) - Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) - if (onSuccess) - onSuccess() - } - finally { - setShowConfirmDelete(false) - } - }, [dataset.id, onSuccess, t]) - - useEffect(() => { - setTags(dataset.tags) - }, [dataset]) + const handleTagAreaClick = (e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + } return ( <>
-      onClick={(e) => {
-        e.preventDefault()
-        if (isExternalProvider)
-          push(`/datasets/${dataset.id}/hitTesting`)
-        else if (isPipelineUnpublished)
-          push(`/datasets/${dataset.id}/pipeline`)
-        else
-          push(`/datasets/${dataset.id}/documents`)
-      }}
+      onClick={handleCardClick}
     >
-      {!dataset.embedding_available && (
-      )}
-      {dataset.embedding_available && dataset.runtime_mode === 'rag_pipeline' && (
-      )}
-        {(isShowChunkingModeIcon || isExternalProvider) && (
-        )}
-          {dataset.name}
-          {dataset.author_name}
-          ·
-          {editTimeText}
-        {isExternalProvider && {t('externalKnowledgeBase', { ns: 'dataset' })}}
-        {!isExternalProvider && isShowDocModeInfo && (
-          <>
-            {dataset.doc_form && (
-              {t(`chunkingMode.${DOC_FORM_TEXT[dataset.doc_form]}`, { ns: 'dataset' })}
-            )}
-            {dataset.indexing_technique && (
-              {formatIndexingTechniqueAndMethod(dataset.indexing_technique, dataset.retrieval_model_dict?.search_method) as any}
-            )}
-            {dataset.is_multimodal && (
-              {t('multimodal', { ns: 'dataset' })}
-            )}
-          </>
-        )}
-        {dataset.description}
-        onClick={(e) => {
-          e.stopPropagation()
-          e.preventDefault()
-        }}
-      >
-          0 && 'visible',
-        )}
-        >
-          tag.id)}
-          selectedTags={tags}
-          onCacheUpdate={setTags}
-          onChange={onSuccess}
-        />
-        {/* Tag Mask */}
-          {documentCount}
-        {!isExternalProvider && (
-            {dataset.app_count}
-        )}
-        /
-        {`${t('updated', { ns: 'dataset' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`}
-        )}
-        className="z-20 min-w-[186px]"
-        popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]"
-        position="br"
-        trigger="click"
-        btnElement={(
-        )}
-        btnClassName={open =>
-          cn(
-            'size-9 cursor-pointer justify-center rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-[2px] ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border',
-            open ? 'border-components-actionbar-border bg-state-base-hover' : '',
-          )}
-      />
-      {showRenameModal && (
+        setShowRenameModal(false)}
+        tags={tags}
+        setTags={setTags}
         onSuccess={onSuccess}
+        isHoveringTagSelector={isHoveringTagSelector}
+        onClick={handleTagAreaClick}
       />
-      )}
-      {showConfirmDelete && (
-        setShowConfirmDelete(false)}
-      )}
   )
 }
diff --git a/web/app/components/goto-anything/actions/commands/account.tsx b/web/app/components/goto-anything/actions/commands/account.tsx
index 6465932a75..d1fa36b6f0 100644
--- a/web/app/components/goto-anything/actions/commands/account.tsx
+++ b/web/app/components/goto-anything/actions/commands/account.tsx
@@ -1,7 +1,7 @@
 import type { SlashCommandHandler } from './types'
 import { RiUser3Line } from '@remixicon/react'
 import * as React from 'react'
-import i18n from '@/i18n-config/i18next-config'
+import { getI18n } from 'react-i18next'
 import { registerCommands, unregisterCommands } from './command-bus'

 // Account command dependency types - no external dependencies needed
@@ -21,6 +21,7 @@
   },

   async search(args: string, locale: string = 'en') {
+    const i18n = getI18n()
     return [{
       id: 'account',
       title: i18n.t('account.account', { ns: 'common', lng: locale }),
diff --git a/web/app/components/goto-anything/actions/commands/community.tsx b/web/app/components/goto-anything/actions/commands/community.tsx
index fcd9a15000..685149402d 100644
--- a/web/app/components/goto-anything/actions/commands/community.tsx
+++ b/web/app/components/goto-anything/actions/commands/community.tsx
@@ -1,7 +1,7 @@
 import type { SlashCommandHandler } from './types'
 import { RiDiscordLine } from '@remixicon/react'
 import * as React from 'react'
-import i18n from '@/i18n-config/i18next-config'
+import { getI18n } from 'react-i18next'
 import { registerCommands, unregisterCommands } from './command-bus'

 // Community command dependency types
@@ -22,6 +22,7 @@
   },

   async search(args: string, locale: string = 'en') {
+    const i18n = getI18n()
     return [{
       id: 'community',
       title: i18n.t('userProfile.community', { ns: 'common', lng: locale }),
diff --git a/web/app/components/goto-anything/actions/commands/docs.tsx b/web/app/components/goto-anything/actions/commands/docs.tsx
index 9f09d32094..8b04e84157 100644
--- a/web/app/components/goto-anything/actions/commands/docs.tsx
+++ b/web/app/components/goto-anything/actions/commands/docs.tsx
@@ -1,8 +1,8 @@
 import type { SlashCommandHandler } from './types'
 import { RiBookOpenLine } from '@remixicon/react'
 import * as React from 'react'
+import { getI18n } from 'react-i18next'
 import { defaultDocBaseUrl } from '@/context/i18n'
-import i18n from '@/i18n-config/i18next-config'
 import { getDocLanguage } from '@/i18n-config/language'
 import { registerCommands, unregisterCommands } from './command-bus'

@@ -19,6 +19,7 @@
   // Direct execution function
   execute: () => {
+    const i18n = getI18n()
     const currentLocale = i18n.language
     const docLanguage = getDocLanguage(currentLocale)
     const url = `${defaultDocBaseUrl}/${docLanguage}`
@@ -26,6 +27,7 @@
   },

   async search(args: string, locale: string = 'en') {
+    const i18n = getI18n()
     return [{
       id: 'doc',
       title: i18n.t('userProfile.helpCenter', { ns: 'common', lng: locale }),
@@ -41,6 +43,7 @@
   },

   register(_deps: DocDeps) {
+    const i18n = getI18n()
     registerCommands({
       'navigation.doc': async (_args) => {
         // Get the current language from i18n
diff --git a/web/app/components/goto-anything/actions/commands/forum.tsx b/web/app/components/goto-anything/actions/commands/forum.tsx
index e32632b4b5..36116ceb1f 100644
--- a/web/app/components/goto-anything/actions/commands/forum.tsx
+++ b/web/app/components/goto-anything/actions/commands/forum.tsx
@@ -1,7 +1,7 @@
 import type { SlashCommandHandler } from './types'
 import { RiFeedbackLine } from '@remixicon/react'
 import * as React from 'react'
-import i18n from '@/i18n-config/i18next-config'
+import { getI18n } from 'react-i18next'
 import { registerCommands, unregisterCommands } from './command-bus'

 // Forum command dependency types
@@ -22,6 +22,7 @@
   },

   async search(args: string, locale: string = 'en') {
+    const i18n = getI18n()
     return [{
       id: 'forum',
       title: i18n.t('userProfile.forum', { ns: 'common', lng: locale }),
diff --git a/web/app/components/goto-anything/actions/commands/language.tsx b/web/app/components/goto-anything/actions/commands/language.tsx
index df94fd49ce..f4bafc1d58 100644
--- a/web/app/components/goto-anything/actions/commands/language.tsx
+++ b/web/app/components/goto-anything/actions/commands/language.tsx
@@ -1,6 +1,6 @@
 import type { CommandSearchResult } from '../types'
 import type { SlashCommandHandler } from './types'
-import i18n from '@/i18n-config/i18next-config'
+import { getI18n } from 'react-i18next'
 import { languages } from '@/i18n-config/language'
 import { registerCommands, unregisterCommands } from './command-bus'

@@ -14,6 +14,7 @@ const buildLanguageCommands = (query: string): CommandSearchResult[] => {
   const list = languages.filter(item => item.supported && (
     !q || item.name.toLowerCase().includes(q) || String(item.value).toLowerCase().includes(q)
   ))
+  const i18n = getI18n()
   return list.map(item => ({
     id: `lang-${item.value}`,
     title: item.name,
diff --git a/web/app/components/goto-anything/actions/commands/slash.tsx b/web/app/components/goto-anything/actions/commands/slash.tsx
index ec0f333cd4..6aad67731f 100644
--- a/web/app/components/goto-anything/actions/commands/slash.tsx
+++ b/web/app/components/goto-anything/actions/commands/slash.tsx
@@ -2,8 +2,8 @@
 import type { ActionItem } from '../types'
 import { useTheme } from 'next-themes'
 import { useEffect } from 'react'
+import { getI18n } from 'react-i18next'
 import { setLocaleOnClient } from '@/i18n-config'
-import i18n from '@/i18n-config/i18next-config'
 import { accountCommand } from './account'
 import { executeCommand } from './command-bus'
 import { communityCommand } from './community'
@@ -14,6 +14,8 @@
 import { slashCommandRegistry } from './registry'
 import { themeCommand } from './theme'
 import { zenCommand } from './zen'

+const i18n = getI18n()
+
 export const slashAction: ActionItem = {
   key: '/',
   shortcut: '/',
diff --git a/web/app/components/goto-anything/actions/commands/theme.tsx b/web/app/components/goto-anything/actions/commands/theme.tsx
index 335182af67..ba1416229d 100644
--- a/web/app/components/goto-anything/actions/commands/theme.tsx
+++ b/web/app/components/goto-anything/actions/commands/theme.tsx
@@ -2,7 +2,7 @@
 import type { CommandSearchResult } from '../types'
 import type { SlashCommandHandler } from './types'
 import { RiComputerLine, RiMoonLine, RiSunLine } from '@remixicon/react'
 import * as React from 'react'
-import i18n from '@/i18n-config/i18next-config'
+import { getI18n } from 'react-i18next'
 import { registerCommands, unregisterCommands } from './command-bus'

 // Theme dependency types
@@ -32,6 +32,7 @@
 ] as const

 const buildThemeCommands = (query: string, locale?: string): CommandSearchResult[] => {
+  const i18n = getI18n()
   const q = query.toLowerCase()
   const list = THEME_ITEMS.filter(item =>
     !q
diff --git a/web/app/components/goto-anything/actions/commands/zen.tsx b/web/app/components/goto-anything/actions/commands/zen.tsx
index d6d9f1e5a2..1645e40fd9 100644
--- a/web/app/components/goto-anything/actions/commands/zen.tsx
+++ b/web/app/components/goto-anything/actions/commands/zen.tsx
@@ -1,8 +1,8 @@
 import type { SlashCommandHandler } from './types'
 import { RiFullscreenLine } from '@remixicon/react'
 import * as React from 'react'
+import { getI18n } from 'react-i18next'
 import { isInWorkflowPage } from '@/app/components/workflow/constants'
-import i18n from '@/i18n-config/i18next-config'
 import { registerCommands, unregisterCommands } from './command-bus'

 // Zen command dependency types - no external dependencies needed
@@ -32,6 +32,7 @@
   execute: toggleZenMode,

   async search(_args: string, locale: string = 'en') {
+    const i18n = getI18n()
     return [{
       id: 'zen',
       title: i18n.t('gotoAnything.actions.zenTitle', { ns: 'app', lng: locale }) || 'Zen Mode',
diff --git a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx
index 956352d6d3..f02e276f55 100644
--- a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx
+++ b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx
@@ -15,7 +15,6 @@
 import Divider from '@/app/components/base/divider'
 import Loading from '@/app/components/base/loading'
 import List from '@/app/components/plugins/marketplace/list'
 import ProviderCard from '@/app/components/plugins/provider-card'
-import { getLocaleOnClient } from '@/i18n-config'
 import { cn } from '@/utils/classnames'
 import { getMarketplaceUrl } from '@/utils/var'
 import {
@@ -33,7 +32,6 @@ const InstallFromMarketplace = ({
   const { t } = useTranslation()
   const { theme } = useTheme()
   const [collapse, setCollapse] = useState(false)
-  const locale = getLocaleOnClient()
   const {
     plugins: allPlugins,
     isLoading: isAllPluginsLoading,
@@ -70,7 +68,6 @@
           marketplaceCollectionPluginsMap={{}}
           plugins={allPlugins}
           showInstallButton
-          locale={locale}
           cardContainerClassName="grid grid-cols-2 gap-2"
           cardRender={cardRender}
           emptyClassName="h-auto"
diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx
index 5d888281e9..2a0604421f 100644
--- a/web/app/components/header/account-setting/language-page/index.tsx
+++ b/web/app/components/header/account-setting/language-page/index.tsx
@@ -2,13 +2,13 @@
 import type { Item } from '@/app/components/base/select'
 import type { Locale } from '@/i18n-config'
+import { useRouter } from 'next/navigation'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useContext } from 'use-context-selector'
 import { SimpleSelect } from '@/app/components/base/select'
 import { ToastContext } from '@/app/components/base/toast'
 import { useAppContext } from '@/context/app-context'
-
 import { useLocale } from '@/context/i18n'
 import { setLocaleOnClient } from '@/i18n-config'
 import { languages } from '@/i18n-config/language'
@@ -25,6 +25,7 @@ export default function LanguagePage() {
   const { notify } = useContext(ToastContext)
   const [editing, setEditing] = useState(false)
   const { t } = useTranslation()
+  const router = useRouter()

   const handleSelectLanguage = async (item: Item) => {
     const url = '/account/interface-language'
@@ -35,7 +36,8 @@
       await updateUserProfile({ url, body: { [bodyKey]: item.value } })
       notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })

-      setLocaleOnClient(item.value.toString() as Locale)
+      setLocaleOnClient(item.value.toString() as Locale, false)
+      router.refresh()
     }
     catch (e) {
       notify({ type: 'error', message: (e as Error).message })
diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx
index f456bcaaa6..d3daaee859 100644
--- a/web/app/components/header/account-setting/model-provider-page/index.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/index.tsx
@@ -31,14 +31,19 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
 const ModelProviderPage = ({ searchText }: Props) => {
   const debouncedSearchText = useDebounce(searchText, { wait: 500 })
   const { t } = useTranslation()
-  const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration)
-  const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
-  const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
-  const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text)
-  const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts)
+  const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
+  const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
+  const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
+  const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text)
+  const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts)
   const { modelProviders: providers } = useProviderContext()
   const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
-  const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
+  const isDefaultModelLoading = isTextGenerationDefaultModelLoading
+    || isEmbeddingsDefaultModelLoading
+    || isRerankDefaultModelLoading
+    || isSpeech2textDefaultModelLoading
+    || isTTSDefaultModelLoading
+  const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
   const [configuredProviders, notConfiguredProviders] = useMemo(() => {
     const configuredProviders: ModelProvider[] = []
     const notConfiguredProviders: ModelProvider[] = []
@@ -106,6 +111,7 @@
             rerankDefaultModel={rerankDefaultModel}
             speech2textDefaultModel={speech2textDefaultModel}
             ttsDefaultModel={ttsDefaultModel}
+            isLoading={isDefaultModelLoading}
           />
diff --git a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx
index 0e7506bf96..289146f2d2 100644
--- a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx
@@ -14,7 +14,6 @@
 import Divider from '@/app/components/base/divider'
 import Loading from '@/app/components/base/loading'
 import List from '@/app/components/plugins/marketplace/list'
 import ProviderCard from '@/app/components/plugins/provider-card'
-import { getLocaleOnClient } from '@/i18n-config'
 import { cn } from '@/utils/classnames'
 import { getMarketplaceUrl } from '@/utils/var'
 import {
@@ -32,7 +31,6 @@ const InstallFromMarketplace = ({
   const { t } = useTranslation()
   const { theme } = useTheme()
   const [collapse, setCollapse] = useState(false)
-  const locale = getLocaleOnClient()
   const {
     plugins: allPlugins,
     isLoading: isAllPluginsLoading,
@@ -69,7 +67,6 @@
           marketplaceCollectionPluginsMap={{}}
           plugins={allPlugins}
           showInstallButton
-          locale={locale}
           cardContainerClassName="grid grid-cols-2 gap-2"
           cardRender={cardRender}
           emptyClassName="h-auto"
diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
index 74222ed56d..29c71e04fc 100644
--- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
@@ -3,7 +3,7 @@ import type {
   DefaultModel,
   DefaultModelResponse,
 } from '../declarations'
-import { RiEqualizer2Line } from '@remixicon/react'
+import { RiEqualizer2Line, RiLoader2Line } from '@remixicon/react'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import Button from '@/app/components/base/button'
@@ -32,6 +32,7 @@
   speech2textDefaultModel: DefaultModelResponse | undefined
   ttsDefaultModel: DefaultModelResponse | undefined
   notConfigured: boolean
+  isLoading?: boolean
 }

 const SystemModel: FC = ({
   textGenerationDefaultModel,
@@ -40,6 +41,7 @@
   speech2textDefaultModel,
   ttsDefaultModel,
   notConfigured,
+  isLoading,
 }) => {
   const { t } = useTranslation()
   const { notify } = useToastContext()
@@ -129,13 +131,16 @@
         crossAxis: 8,
       }}
     >
-      setOpen(v => !v)}>
+      setOpen(v => !v)}>
diff --git a/web/app/components/i18n-server.tsx b/web/app/components/i18n-server.tsx
deleted file mode 100644
index 01dc5f0f13..0000000000
--- a/web/app/components/i18n-server.tsx
+++ /dev/null
@@ -1,22 +0,0 @@
-import * as React from 'react'
-import { getLocaleOnServer } from '@/i18n-config/server'
-import { ToastProvider } from './base/toast'
-import I18N from './i18n'
-
-export type II18NServerProps = {
-  children: React.ReactNode
-}
-
-const I18NServer = async ({
-  children,
-}: II18NServerProps) => {
-  const locale = await getLocaleOnServer()
-
-  return (
-      {children}
-  )
-}
-
-export default I18NServer
diff --git a/web/app/components/i18n.tsx b/web/app/components/i18n.tsx
deleted file mode 100644
index e9af2face9..0000000000
--- a/web/app/components/i18n.tsx
+++ /dev/null
@@ -1,45 +0,0 @@
-'use client'
-
-import type { FC } from 'react'
-import type { Locale } from '@/i18n-config'
-import { usePrefetchQuery } from '@tanstack/react-query'
-import { useHydrateAtoms } from 'jotai/utils'
-import * as React from 'react'
-import { useEffect, useState } from 'react'
-import { localeAtom } from '@/context/i18n'
-import { setLocaleOnClient } from '@/i18n-config'
-import { getSystemFeatures } from '@/service/common'
-import Loading from './base/loading'
-
-export type II18nProps = {
-  locale: Locale
-  children: React.ReactNode
-}
-const I18n: FC = ({
-  locale,
-  children,
-}) => {
-  useHydrateAtoms([[localeAtom, locale]])
-  const [loading, setLoading] = useState(true)
-
-  usePrefetchQuery({
-    queryKey: ['systemFeatures'],
-    queryFn: getSystemFeatures,
-  })
-
-  useEffect(() => {
-    setLocaleOnClient(locale, false).then(() => {
-      setLoading(false)
-    })
-  }, [locale])
-
-  if (loading)
-    return
-
-  return (
-    <>
-      {children}
-    </>
-  )
-}
-export default React.memo(I18n)
diff --git a/web/app/components/plugins/base/deprecation-notice.tsx b/web/app/components/plugins/base/deprecation-notice.tsx
index c2ddfa6975..513b27a2cf 100644
--- a/web/app/components/plugins/base/deprecation-notice.tsx
+++ b/web/app/components/plugins/base/deprecation-notice.tsx
@@ -1,4 +1,5 @@
 import type { FC } from 'react'
+import { useTranslation } from '#i18n'
 import { RiAlertFill } from '@remixicon/react'
 import { camelCase } from 'es-toolkit/string'
 import Link from 'next/link'
@@ -6,14 +7,12 @@
 import * as React from 'react'
 import { useMemo } from 'react'
 import { Trans } from 'react-i18next'
 import { cn } from '@/utils/classnames'
-import { useMixedTranslation } from '../marketplace/hooks'

 type DeprecationNoticeProps = {
   status: 'deleted' | 'active'
   deprecatedReason: string
   alternativePluginId: string
   alternativePluginURL: string
-  locale?: string
   className?: string
   innerWrapperClassName?: string
   iconWrapperClassName?: string
@@ -34,13 +33,12 @@
   deprecatedReason,
   alternativePluginId,
   alternativePluginURL,
-  locale,
   className,
   innerWrapperClassName,
   iconWrapperClassName,
   textClassName,
 }) => {
-  const { t } = useMixedTranslation(locale)
+  const { t } = useTranslation()

   const deprecatedReasonKey = useMemo(() => {
     if (!deprecatedReason)
diff --git a/web/app/components/plugins/card/index.spec.tsx b/web/app/components/plugins/card/index.spec.tsx
index 4a3e5a587b..fd97534ec4 100644
--- a/web/app/components/plugins/card/index.spec.tsx
+++ b/web/app/components/plugins/card/index.spec.tsx
@@ -502,31 +502,6 @@
     })
   })

-  // ================================
-  // Locale Tests
-  // ================================
-  describe('Locale', () => {
-    it('should use locale from props when provided', () => {
-      const plugin = createMockPlugin({
-        label: { 'en-US': 'English Title', 'zh-Hans': '中文标题' },
-      })
-
-      render()
-
-      expect(screen.getByText('中文标题')).toBeInTheDocument()
-    })
-
-    it('should fallback to default locale when prop locale not found', () => {
-      const plugin = createMockPlugin({
-        label: { 'en-US': 'English Title' },
-      })
-
-      render()
-
-      expect(screen.getByText('English Title')).toBeInTheDocument()
-    })
-  })
-
   // ================================
   // Memoization Tests
   // ================================
diff --git a/web/app/components/plugins/card/index.tsx b/web/app/components/plugins/card/index.tsx
index ada26801de..8578421116 100644
--- a/web/app/components/plugins/card/index.tsx
+++ b/web/app/components/plugins/card/index.tsx
@@ -1,15 +1,13 @@
 'use client'
 import type { Plugin } from '../types'
-import type { Locale } from '@/i18n-config'
+import { useTranslation } from '#i18n'
 import { RiAlertFill } from '@remixicon/react'
 import * as React from 'react'
-import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks'
 import { useGetLanguage } from '@/context/i18n'
 import useTheme from '@/hooks/use-theme'
 import {
   renderI18nObject,
 } from '@/i18n-config'
-import { getLanguage } from '@/i18n-config/language'
 import { Theme } from '@/types/app'
 import { cn } from '@/utils/classnames'
 import Partner from '../base/badges/partner'
@@ -33,7 +31,6 @@
   footer?: React.ReactNode
   isLoading?: boolean
   loadingFileName?: string
-  locale?: Locale
   limitedInstall?: boolean
 }

@@ -48,13 +45,11 @@
   footer,
   isLoading = false,
   loadingFileName,
-  locale: localeFromProps,
   limitedInstall = false,
 }: Props) => {
-  const defaultLocale = useGetLanguage()
-  const locale = localeFromProps ? getLanguage(localeFromProps) : defaultLocale
-  const { t } = useMixedTranslation(localeFromProps)
-  const { categoriesMap } = useCategories(t, true)
+  const locale = useGetLanguage()
+  const { t } = useTranslation()
+  const { categoriesMap } = useCategories(true)
   const { category, type, name, org, label, brief, icon, icon_dark, verified, badges = [] } = payload
   const { theme } = useTheme()
   const iconSrc = theme === Theme.dark && icon_dark ? icon_dark : icon
diff --git a/web/app/components/plugins/hooks.ts b/web/app/components/plugins/hooks.ts
index 262935205b..65d073cc2f 100644
--- a/web/app/components/plugins/hooks.ts
+++ b/web/app/components/plugins/hooks.ts
@@ -1,4 +1,3 @@
-import type { TFunction } from 'i18next'
 import type { CategoryKey, TagKey } from './constants'
 import { useMemo } from 'react'
 import { useTranslation } from 'react-i18next'
@@ -13,9 +12,8 @@
   label: string
 }

-export const useTags = (translateFromOut?: TFunction) => {
-  const { t: translation } = useTranslation()
-  const t = translateFromOut || translation
+export const useTags = () => {
+  const { t } = useTranslation()

   const tags = useMemo(() => {
     return tagKeys.map((tag) => {
@@ -53,9 +51,8 @@
   label: string
 }

-export const useCategories = (translateFromOut?: TFunction, isSingle?: boolean) => {
-  const { t: translation } = useTranslation()
-  const t = translateFromOut || translation
+export const useCategories = (isSingle?: boolean) => {
+  const { t } = useTranslation()

   const categories = useMemo(() => {
     return categoryKeys.map((category) => {
diff --git a/web/app/components/plugins/marketplace/description/index.spec.tsx b/web/app/components/plugins/marketplace/description/index.spec.tsx
index b5c8cb716b..054949ee1f 100644
--- a/web/app/components/plugins/marketplace/description/index.spec.tsx
+++ b/web/app/components/plugins/marketplace/description/index.spec.tsx
@@ -1,7 +1,5 @@
 import { render, screen } from '@testing-library/react'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
-
-// Import component after mocks are set up
 import Description from './index'

 // ================================
@@ -30,20 +28,18 @@
   'operation.in': 'in',
 }

-// Mock getLocaleOnServer and translate
-vi.mock('@/i18n-config/server', () => ({
-  getLocaleOnServer: vi.fn(() => Promise.resolve(mockDefaultLocale)),
-  getTranslation: vi.fn((locale: string, ns: string) => {
-    return Promise.resolve({
-      t: (key: string) => {
-        if (ns === 'plugin')
-          return pluginTranslations[key] || key
-        if (ns === 'common')
-          return commonTranslations[key] || key
-        return key
-      },
-    })
-  }),
+// Mock i18n hooks
+vi.mock('#i18n', () => ({
+  useLocale: vi.fn(() => mockDefaultLocale),
+  useTranslation: vi.fn((ns: string) => ({
+    t: (key: string) => {
+      if (ns === 'plugin')
+        return pluginTranslations[key] || key
+      if (ns === 'common')
+        return commonTranslations[key] || key
+      return key
+    },
+  })),
 }))

 // ================================
@@ -59,29 +55,29 @@ describe('Description', () => {
   // ================================
   // Rendering Tests
   // ================================
   describe('Rendering', () => {
-    it('should render without crashing', async () => {
-      const { container } = render(await Description({}))
+    it('should render without crashing', () => {
+      const { container } = render(<Description />)

       expect(container.firstChild).toBeInTheDocument()
     })

-    it('should render h1 heading with empower text', async () => {
-      render(await Description({}))
+    it('should render h1 heading with empower text', () => {
+      render(<Description />)

       const heading = screen.getByRole('heading', { level: 1 })
       expect(heading).toBeInTheDocument()
       expect(heading).toHaveTextContent('Empower your AI development')
     })

-    it('should render h2 subheading', async () => {
-      render(await Description({}))
+    it('should render h2 subheading', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toBeInTheDocument()
     })

-    it('should apply correct CSS classes to h1', async () => {
-      render(await Description({}))
+    it('should apply correct CSS classes to h1', () => {
+      render(<Description />)

       const heading = screen.getByRole('heading', { level: 1 })
       expect(heading).toHaveClass('title-4xl-semi-bold')
@@ -90,8 +86,8 @@
       expect(heading).toHaveClass('text-text-primary')
     })

-    it('should apply correct CSS classes to h2', async () => {
-      render(await Description({}))
+    it('should apply correct CSS classes to h2', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toHaveClass('body-md-regular')
@@ -104,14 +100,18 @@
   // Non-Chinese Locale Rendering Tests
   // ================================
   describe('Non-Chinese Locale Rendering', () => {
-    it('should render discover text for en-US locale', async () => {
-      render(await Description({ locale: 'en-US' }))
+    beforeEach(() => {
+      mockDefaultLocale = 'en-US'
+    })
+
+    it('should render discover text for en-US locale', () => {
+      render(<Description />)

       expect(screen.getByText(/Discover/)).toBeInTheDocument()
     })

-    it('should render all category names', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render all category names', () => {
+      render(<Description />)

       expect(screen.getByText('Models')).toBeInTheDocument()
       expect(screen.getByText('Tools')).toBeInTheDocument()
@@ -122,36 +122,36 @@
       expect(screen.getByText('Bundles')).toBeInTheDocument()
     })

-    it('should render "and" conjunction text', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render "and" conjunction text', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading.textContent).toContain('and')
     })

-    it('should render "in" preposition at the end for non-Chinese locales', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render "in" preposition at the end for non-Chinese locales', () => {
+      render(<Description />)

       expect(screen.getByText('in')).toBeInTheDocument()
     })

-    it('should render Dify Marketplace text at the end for non-Chinese locales', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render Dify Marketplace text at the end for non-Chinese locales', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading.textContent).toContain('Dify Marketplace')
     })

-    it('should render category spans with styled underline effect', async () => {
-      const { container } = render(await Description({ locale: 'en-US' }))
+    it('should render category spans with styled underline effect', () => {
+      const { container } = render(<Description />)

       const styledSpans = container.querySelectorAll('.body-md-medium.relative.z-\\[1\\]')
       // 7 category spans (models, tools, datasources, triggers, agents, extensions, bundles)
       expect(styledSpans.length).toBe(7)
     })

-    it('should apply text-text-secondary class to category spans', async () => {
-      const { container } = render(await Description({ locale: 'en-US' }))
+    it('should apply text-text-secondary class to category spans', () => {
+      const { container } = render(<Description />)

       const styledSpans = container.querySelectorAll('.text-text-secondary')
       expect(styledSpans.length).toBeGreaterThanOrEqual(7)
@@ -162,29 +162,33 @@
   // Chinese (zh-Hans) Locale Rendering Tests
   // ================================
   describe('Chinese (zh-Hans) Locale Rendering', () => {
-    it('should render "in" text at the beginning for zh-Hans locale', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    beforeEach(() => {
+      mockDefaultLocale = 'zh-Hans'
+    })
+
+    it('should render "in" text at the beginning for zh-Hans locale', () => {
+      render(<Description />)

       // In zh-Hans mode, "in" appears at the beginning
       const inElements = screen.getAllByText('in')
       expect(inElements.length).toBeGreaterThanOrEqual(1)
     })

-    it('should render Dify Marketplace text for zh-Hans locale', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render Dify Marketplace text for zh-Hans locale', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading.textContent).toContain('Dify Marketplace')
     })

-    it('should render discover text for zh-Hans locale', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render discover text for zh-Hans locale', () => {
+      render(<Description />)

       expect(screen.getByText(/Discover/)).toBeInTheDocument()
     })

-    it('should render all categories for zh-Hans locale', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render all categories for zh-Hans locale', () => {
+      render(<Description />)

       expect(screen.getByText('Models')).toBeInTheDocument()
       expect(screen.getByText('Tools')).toBeInTheDocument()
@@ -195,8 +199,8 @@
       expect(screen.getByText('Bundles')).toBeInTheDocument()
     })

-    it('should render both zh-Hans specific elements and shared elements', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render both zh-Hans specific elements and shared elements', () => {
+      render(<Description />)

       // zh-Hans has specific element order: "in" -> Dify Marketplace -> Discover
       // then the same category list with "and" -> Bundles
@@ -206,61 +210,57 @@
   })

   // ================================
-  // Locale Prop Variations Tests
+  // Locale Variations Tests
   // ================================
-  describe('Locale Prop Variations', () => {
-    it('should use default locale when locale prop is undefined', async () => {
+  describe('Locale Variations', () => {
+    it('should use en-US locale by default', () => {
       mockDefaultLocale = 'en-US'
-      render(await Description({}))
+      render(<Description />)

-      // Should use the default locale from getLocaleOnServer
       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })

-    it('should use provided locale prop instead of default', async () => {
+    it('should handle ja-JP locale as non-Chinese', () => {
       mockDefaultLocale = 'ja-JP'
-      render(await Description({ locale: 'en-US' }))
-
-      // The locale prop should be used, triggering non-Chinese rendering
-      const subheading = screen.getByRole('heading', { level: 2 })
-      expect(subheading).toBeInTheDocument()
-    })
-
-    it('should handle ja-JP locale as non-Chinese', async () => {
-      render(await Description({ locale: 'ja-JP' }))
+      render(<Description />)

       // Should render in non-Chinese format (discover first, then "in Dify Marketplace" at end)
       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading.textContent).toContain('Dify Marketplace')
     })

-    it('should handle ko-KR locale as non-Chinese', async () => {
-      render(await Description({ locale: 'ko-KR' }))
+    it('should handle ko-KR locale as non-Chinese', () => {
+      mockDefaultLocale = 'ko-KR'
+      render(<Description />)

       // Should render in non-Chinese format
       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })

-    it('should handle de-DE locale as non-Chinese', async () => {
-      render(await Description({ locale: 'de-DE' }))
+    it('should handle de-DE locale as non-Chinese', () => {
+      mockDefaultLocale = 'de-DE'
+      render(<Description />)

       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })

-    it('should handle fr-FR locale as non-Chinese', async () => {
-      render(await Description({ locale: 'fr-FR' }))
+    it('should handle fr-FR locale as non-Chinese', () => {
+      mockDefaultLocale = 'fr-FR'
+      render(<Description />)

       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })

-    it('should handle pt-BR locale as non-Chinese', async () => {
-      render(await Description({ locale: 'pt-BR' }))
+    it('should handle pt-BR locale as non-Chinese', () => {
+      mockDefaultLocale = 'pt-BR'
+      render(<Description />)

       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })

-    it('should handle es-ES locale as non-Chinese', async () => {
-      render(await Description({ locale: 'es-ES' }))
+    it('should handle es-ES locale as non-Chinese', () => {
+      mockDefaultLocale = 'es-ES'
+      render(<Description />)

       expect(screen.getByText('Empower your AI development')).toBeInTheDocument()
     })
@@ -270,24 +270,27 @@
   // Conditional Rendering Tests
   // ================================
   describe('Conditional Rendering', () => {
-    it('should render zh-Hans specific content when locale is zh-Hans', async () => {
-      const { container } = render(await Description({ locale: 'zh-Hans' }))
+    it('should render zh-Hans specific content when locale is zh-Hans', () => {
+      mockDefaultLocale = 'zh-Hans'
+      const { container } = render(<Description />)

       // zh-Hans has additional span with mr-1 before "in" text at the start
       const mrSpan = container.querySelector('span.mr-1')
       expect(mrSpan).toBeInTheDocument()
     })

-    it('should render non-Chinese specific content when locale is not zh-Hans', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render non-Chinese specific content when locale is not zh-Hans', () => {
+      mockDefaultLocale = 'en-US'
+      render(<Description />)

       // Non-Chinese has "in" and "Dify Marketplace" at the end
       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading.textContent).toContain('Dify Marketplace')
     })

-    it('should not render zh-Hans intro content for non-Chinese locales', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should not render zh-Hans intro content for non-Chinese locales', () => {
+      mockDefaultLocale = 'en-US'
+      render(<Description />)

       // For en-US, the order should be Discover ... in Dify Marketplace
       // The "in" text should only appear once at the end
@@ -303,8 +306,9 @@
       expect(inIndex).toBeLessThan(marketplaceIndex)
     })

-    it('should render zh-Hans with proper word order', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render zh-Hans with proper word order', () => {
+      mockDefaultLocale = 'zh-Hans'
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       const content = subheading.textContent || ''
@@ -323,58 +327,58 @@
   // Category Styling Tests
   // ================================
   describe('Category Styling', () => {
-    it('should apply underline effect with after pseudo-element styling', async () => {
-      const { container } = render(await Description({}))
+    it('should apply underline effect with after pseudo-element styling', () => {
+      const { container } = render(<Description />)

       const categorySpan = container.querySelector('.after\\:absolute')
       expect(categorySpan).toBeInTheDocument()
     })

-    it('should apply correct after pseudo-element classes', async () => {
-      const { container } = render(await Description({}))
+    it('should apply correct after pseudo-element classes', () => {
+      const { container } = render(<Description />)

       // Check for the specific after pseudo-element classes
       const categorySpans = container.querySelectorAll('.after\\:bottom-\\[1\\.5px\\]')
       expect(categorySpans.length).toBe(7)
     })

-    it('should apply full width to after element', async () => {
-      const { container } = render(await Description({}))
+    it('should apply full width to after element', () => {
+      const { container } = render(<Description />)

       const categorySpans = container.querySelectorAll('.after\\:w-full')
       expect(categorySpans.length).toBe(7)
     })

-    it('should apply correct height to after element', async () => {
-      const { container } = render(await Description({}))
+    it('should apply correct height to after element', () => {
+      const { container } = render(<Description />)

       const categorySpans = container.querySelectorAll('.after\\:h-2')
       expect(categorySpans.length).toBe(7)
     })

-    it('should apply bg-text-text-selected to after element', async () => {
-      const { container } = render(await Description({}))
+    it('should apply bg-text-text-selected to after element', () => {
+      const { container } = render(<Description />)

       const categorySpans = container.querySelectorAll('.after\\:bg-text-text-selected')
       expect(categorySpans.length).toBe(7)
     })

-    it('should have z-index 1 on category spans', async () => {
-      const { container } = render(await Description({}))
+    it('should have z-index 1 on category spans', () => {
+      const { container } = render(<Description />)

       const categorySpans = container.querySelectorAll('.z-\\[1\\]')
       expect(categorySpans.length).toBe(7)
     })

-    it('should apply left margin to category spans', async () => {
-      const { container } = render(await Description({}))
+    it('should apply left margin to category spans', () => {
+      const { container } = render(<Description />)

       const categorySpans = container.querySelectorAll('.ml-1')
       expect(categorySpans.length).toBeGreaterThanOrEqual(7)
     })

-    it('should apply both left and right margin to specific spans', async () => {
-      const { container } = render(await Description({}))
+    it('should apply both left and right margin to specific spans', () => {
+      const { container } = render(<Description />)

       // Extensions and Bundles spans have both ml-1 and mr-1
       const extensionsBundlesSpans = container.querySelectorAll('.ml-1.mr-1')
@@ -386,28 +390,17 @@
   // Edge Cases Tests
   // ================================
   describe('Edge Cases', () => {
-    it('should handle empty props object', async () => {
-      const { container } = render(await Description({}))
-
-      expect(container.firstChild).toBeInTheDocument()
-    })
-
-    it('should render fragment as root element', async () => {
-      const { container } = render(await Description({}))
+    it('should render fragment as root element', () => {
+      const { container } = render(<Description />)

       // Fragment renders h1 and h2 as direct children
       expect(container.querySelector('h1')).toBeInTheDocument()
       expect(container.querySelector('h2')).toBeInTheDocument()
     })

-    it('should handle locale prop with undefined value', async () => {
-      render(await Description({ locale: undefined }))
-
-      expect(screen.getByRole('heading', { level: 1 })).toBeInTheDocument()
-    })
-
-    it('should handle zh-Hant as non-Chinese simplified', async () => {
-      render(await Description({ locale: 'zh-Hant' }))
+    it('should handle zh-Hant as non-Chinese simplified', () => {
+      mockDefaultLocale = 'zh-Hant'
+      render(<Description />)

       // zh-Hant is different from zh-Hans, should use non-Chinese format
       const subheading = screen.getByRole('heading', { level: 2 })
@@ -426,8 +419,8 @@
   // Content Structure Tests
   // ================================
   describe('Content Structure', () => {
-    it('should have comma separators between categories', async () => {
-      render(await Description({}))
+    it('should have comma separators between categories', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       const content = subheading.textContent || ''
@@ -436,8 +429,8 @@
       expect(content).toMatch(/Models[^\n\r,\u2028\u2029]*,.*Tools[^\n\r,\u2028\u2029]*,.*Data Sources[^\n\r,\u2028\u2029]*,.*Triggers[^\n\r,\u2028\u2029]*,.*Agent Strategies[^\n\r,\u2028\u2029]*,.*Extensions/)
     })

-    it('should have "and" before last category (Bundles)', async () => {
-      render(await Description({}))
+    it('should have "and" before last category (Bundles)', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       const content = subheading.textContent || ''
@@ -449,8 +442,9 @@
       expect(andIndex).toBeLessThan(bundlesIndex)
     })

-    it('should render all text elements in correct order for en-US', async () => {
-      render(await Description({ locale: 'en-US' }))
+    it('should render all text elements in correct order for en-US', () => {
+      mockDefaultLocale = 'en-US'
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       const content = subheading.textContent || ''
@@ -477,8 +471,9 @@
       }
     })

-    it('should render all text elements in correct order for zh-Hans', async () => {
-      render(await Description({ locale: 'zh-Hans' }))
+    it('should render all text elements in correct order for zh-Hans', () => {
+      mockDefaultLocale = 'zh-Hans'
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       const content = subheading.textContent || ''
@@ -499,82 +494,48 @@
   // Layout Tests
   // ================================
   describe('Layout', () => {
-    it('should have shrink-0 on h1 heading', async () => {
-      render(await Description({}))
+    it('should have shrink-0 on h1 heading', () => {
+      render(<Description />)

       const heading = screen.getByRole('heading', { level: 1 })
       expect(heading).toHaveClass('shrink-0')
     })

-    it('should have shrink-0 on h2 subheading', async () => {
-      render(await Description({}))
+    it('should have shrink-0 on h2 subheading', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toHaveClass('shrink-0')
     })

-    it('should have flex layout on h2', async () => {
-      render(await Description({}))
+    it('should have flex layout on h2', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toHaveClass('flex')
     })

-    it('should have items-center on h2', async () => {
-      render(await Description({}))
+    it('should have items-center on h2', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toHaveClass('items-center')
     })

-    it('should have justify-center on h2', async () => {
-      render(await Description({}))
+    it('should have justify-center on h2', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toHaveClass('justify-center')
     })
   })

-  // ================================
-  // Translation Function Tests
-  // ================================
-  describe('Translation Functions', () => {
-    it('should call getTranslation for plugin namespace', async () => {
-      const { getTranslation } = await import('@/i18n-config/server')
-      render(await Description({ locale: 'en-US' }))
-
-      expect(getTranslation).toHaveBeenCalledWith('en-US', 'plugin')
-    })
-
-    it('should call getTranslation for common namespace', async () => {
-      const { getTranslation } = await import('@/i18n-config/server')
-      render(await Description({ locale: 'en-US' }))
-
-      expect(getTranslation).toHaveBeenCalledWith('en-US', 'common')
-    })
-
-    it('should call getLocaleOnServer when locale prop is undefined', async () => {
-      const { getLocaleOnServer } = await import('@/i18n-config/server')
-      render(await Description({}))
-
-      expect(getLocaleOnServer).toHaveBeenCalled()
-    })
-
-    it('should use locale prop when provided', async () => {
-      const { getTranslation } = await import('@/i18n-config/server')
-      render(await Description({ locale: 'ja-JP' }))
-
-      expect(getTranslation).toHaveBeenCalledWith('ja-JP', 'plugin')
-      expect(getTranslation).toHaveBeenCalledWith('ja-JP', 'common')
-    })
-  })
-
   // ================================
   // Accessibility Tests
   // ================================
   describe('Accessibility', () => {
-    it('should have proper heading hierarchy', async () => {
-      render(await Description({}))
+    it('should have proper heading hierarchy', () => {
+      render(<Description />)

       const h1 = screen.getByRole('heading', { level: 1 })
       const h2 = screen.getByRole('heading', { level: 2 })
@@ -583,22 +544,22 @@
       expect(h2).toBeInTheDocument()
     })

-    it('should have readable text content', async () => {
-      render(await Description({}))
+    it('should have readable text content', () => {
+      render(<Description />)

       const h1 = screen.getByRole('heading', { level: 1 })
       expect(h1.textContent).not.toBe('')
     })

-    it('should have visible h1 heading', async () => {
-      render(await Description({}))
+    it('should have visible h1 heading', () => {
+      render(<Description />)

       const heading = screen.getByRole('heading', { level: 1 })
       expect(heading).toBeVisible()
     })

-    it('should have visible h2 heading', async () => {
-      render(await Description({}))
+    it('should have visible h2 heading', () => {
+      render(<Description />)

       const subheading = screen.getByRole('heading', { level: 2 })
       expect(subheading).toBeVisible()
@@ -615,8 +576,8 @@
     mockDefaultLocale = 'en-US'
   })

-  it('should render complete component structure', async () => {
-    const { container } = render(await Description({ locale: 'en-US' }))
+  it('should render complete component structure', () => {
+    const { container } = render(<Description />)

     // Main headings
     expect(container.querySelector('h1')).toBeInTheDocument()
@@ -627,8 +588,9 @@
     expect(categorySpans.length).toBe(7)
   })

-  it('should render complete zh-Hans structure', async () => {
-    const { container } = render(await Description({ locale: 'zh-Hans' }))
+  it('should render complete zh-Hans structure', () => {
+    mockDefaultLocale = 'zh-Hans'
+    const { container } = render(<Description />)

     // Main headings
     expect(container.querySelector('h1')).toBeInTheDocument()
@@ -639,14 +601,16 @@
     expect(categorySpans.length).toBe(7)
   })

-  it('should correctly switch between zh-Hans and en-US layouts', async () => {
+  it('should correctly differentiate between zh-Hans and en-US layouts', () => {
     // Render en-US
-    const { container: enContainer, unmount: unmountEn } = render(await Description({ locale: 'en-US' }))
+    mockDefaultLocale = 'en-US'
+    const { container: enContainer, unmount: unmountEn } = render(<Description />)
     const enContent = enContainer.querySelector('h2')?.textContent || ''
     unmountEn()

     // Render zh-Hans
-    const { container: zhContainer } = render(await Description({ locale: 'zh-Hans' }))
+    mockDefaultLocale = 'zh-Hans'
+    const { container: zhContainer } = render(<Description />)
     const zhContent = zhContainer.querySelector('h2')?.textContent || ''

     // Both should have all categories
@@ -666,14 +630,16 @@
     expect(zhMarketplaceIndex).toBeLessThan(zhDiscoverIndex)
   })

-  it('should maintain consistent styling across locales', async () => {
+  it('should maintain consistent styling across locales', () => {
     // Render en-US
-    const { container: enContainer, unmount: unmountEn } = render(await Description({ locale: 'en-US' }))
+    mockDefaultLocale = 'en-US'
+    const { container: enContainer, unmount: unmountEn } = render(<Description />)
     const enCategoryCount = enContainer.querySelectorAll('.body-md-medium').length
     unmountEn()

     // Render zh-Hans
-    const { container: zhContainer } = render(await Description({ locale: 'zh-Hans' }))
+    mockDefaultLocale = 'zh-Hans'
+    const { container: zhContainer } = render(<Description />)
     const zhCategoryCount = zhContainer.querySelectorAll('.body-md-medium').length

     // Both should have same number of styled category spans
diff --git a/web/app/components/plugins/marketplace/description/index.tsx b/web/app/components/plugins/marketplace/description/index.tsx
index d3ca964538..30ccbdb76e 100644
--- a/web/app/components/plugins/marketplace/description/index.tsx
+++ b/web/app/components/plugins/marketplace/description/index.tsx
@@ -1,17 +1,11 @@
-/* eslint-disable dify-i18n/require-ns-option */
-import type { Locale } from '@/i18n-config'
-import { getLocaleOnServer, getTranslation } from '@/i18n-config/server'
+import { useLocale, useTranslation } from '#i18n'

-type DescriptionProps = {
-  locale?: Locale
-}
-const Description = async ({
-  locale: localeFromProps,
-}: DescriptionProps) => {
-  const localeDefault = await getLocaleOnServer()
-  const { t } = await getTranslation(localeFromProps || localeDefault, 'plugin')
-  const { t: tCommon } = await getTranslation(localeFromProps || localeDefault, 'common')
-  const isZhHans = localeFromProps === 'zh-Hans'
+const Description = () => {
+  const { t } = useTranslation('plugin')
+  const { t: tCommon } = useTranslation('common')
+  const locale = useLocale()
+
+  const isZhHans = locale === 'zh-Hans'

   return (
     <>
diff --git a/web/app/components/plugins/marketplace/empty/index.spec.tsx b/web/app/components/plugins/marketplace/empty/index.spec.tsx
index 4cbc85a309..bc8e701dfc 100644
--- a/web/app/components/plugins/marketplace/empty/index.spec.tsx
+++ b/web/app/components/plugins/marketplace/empty/index.spec.tsx
@@ -7,9 +7,9 @@
 import Line from './line'

 // ================================
 // Mock external dependencies only
 // ================================

-// Mock useMixedTranslation hook
-vi.mock('../hooks', () => ({
-  useMixedTranslation: (_locale?: string) => ({
+// Mock i18n translation hook
+vi.mock('#i18n', () => ({
+  useTranslation: () => ({
     t: (key: string, options?: { ns?: string }) => {
       // Build full key with namespace prefix if provided
       const fullKey = options?.ns ? `${options.ns}.${key}` : key
@@ -471,36 +471,6 @@
     })
   })

-  // ================================
-  // Locale Prop Tests
-  // ================================
-  describe('Locale Prop', () => {
-    it('should pass locale to useMixedTranslation', () => {
-      render()
-
-      // Translation should still work
-      expect(screen.getByText('No plugin found')).toBeInTheDocument()
-    })
-
-    it('should handle undefined locale', () => {
-      render()
-
-      expect(screen.getByText('No plugin found')).toBeInTheDocument()
-    })
-
-    it('should handle en-US locale', () => {
-      render()
-
-      expect(screen.getByText('No plugin found')).toBeInTheDocument()
-    })
-
-    it('should handle ja-JP locale', () => {
-      render()
-
-      expect(screen.getByText('No plugin found')).toBeInTheDocument()
-    })
-  })
-
   // ================================
   // Placeholder Cards Layout Tests
   // ================================
@@ -634,7 +604,6 @@
         text="Custom message"
         lightCard
         className="custom-wrapper"
-        locale="en-US"
       />,
     )
@@ -695,12 +664,6 @@
     expect(container.querySelector('.only-class')).toBeInTheDocument()
   })

-  it('should render with only locale prop', () => {
-    render()
-
-    expect(screen.getByText('No plugin found')).toBeInTheDocument()
-  })
-
   it('should handle text with unicode characters', () => {
     render()

@@ -813,7 +776,7 @@
   })

   it('should render complete Empty component structure', () => {
-    const { container } = render()
+    const { container } = render()

     // Container
     expect(container.querySelector('.test')).toBeInTheDocument()
diff --git a/web/app/components/plugins/marketplace/empty/index.tsx b/web/app/components/plugins/marketplace/empty/index.tsx
index 3c33d9b92a..6e5adff1b4 100644
--- a/web/app/components/plugins/marketplace/empty/index.tsx
+++ b/web/app/components/plugins/marketplace/empty/index.tsx
@@ -1,6 +1,6 @@
 'use client'
+import { useTranslation } from '#i18n'
 import { Group } from '@/app/components/base/icons/src/vender/other'
-import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks'
 import { cn } from '@/utils/classnames'
 import Line from './line'

@@ -8,16 +8,14 @@
   text?: string
   lightCard?: boolean
   className?: string
-  locale?: string
 }

 const Empty = ({
   text,
   lightCard,
   className,
-  locale,
 }: Props) => {
-  const { t } = useMixedTranslation(locale)
+  const { t } = useTranslation()

   return (
   )
 }
diff --git a/web/app/components/plugins/marketplace/hooks.ts b/web/app/components/plugins/marketplace/hooks.ts
   }
 }

-/**
- * ! Support zh-Hans, pt-BR, ja-JP and en-US for Marketplace page
- * ! For other languages, use en-US as fallback
- */
-export const useMixedTranslation = (localeFromOuter?: string) => {
-  let t = useTranslation().t
-
-  if (localeFromOuter)
-    t = i18n.getFixedT(localeFromOuter)
-
-  return {
-    t,
-  }
-}
-
 export const useMarketplaceContainerScroll = (
   callback: () => void,
   scrollContainerId = 'marketplace-container',
diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx
index 3073897ba1..b3b1d58dd4 100644
--- a/web/app/components/plugins/marketplace/index.spec.tsx
+++ b/web/app/components/plugins/marketplace/index.spec.tsx
@@ -11,7 +11,6 @@
 import { PluginCategoryEnum } from '@/app/components/plugins/types'
 // Note: Import after mocks are set up
 import { DEFAULT_SORT, SCROLL_BOTTOM_THRESHOLD } from './constants'
 import { MarketplaceContext, MarketplaceContextProvider, useMarketplaceContext } from './context'
-import { useMixedTranslation } from './hooks'
 import PluginTypeSwitch, { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch'
 import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper'
 import {
@@ -602,48 +601,6 @@
   })
 })

-// ================================
-// Hooks Tests
-// ================================
-describe('hooks', () => {
-  describe('useMixedTranslation', () => {
-    it('should return translation function', () => {
-      const { result } = renderHook(() => useMixedTranslation())
-
-      expect(result.current.t).toBeDefined()
-      expect(typeof result.current.t).toBe('function')
-    })
-
-    it('should return translation key when no translation found', () => {
-      const { result } = renderHook(() => useMixedTranslation())
-
-      // The global mock returns key with namespace prefix
-      expect(result.current.t('category.all', { ns: 'plugin' })).toBe('plugin.category.all')
-    })
-
-    it('should use locale from outer when provided', () => {
-      const { result } = renderHook(() => useMixedTranslation('zh-Hans'))
-
-      expect(result.current.t).toBeDefined()
-    })
-
-    it('should handle different locale values', () => {
-      const locales = ['en-US', 'zh-Hans', 'ja-JP', 'pt-BR']
-      locales.forEach((locale) => {
-        const { result } = renderHook(() => useMixedTranslation(locale))
-        expect(result.current.t).toBeDefined()
-        expect(typeof result.current.t).toBe('function')
-      })
-    })
-
-    it('should use getFixedT when localeFromOuter is provided', () => {
-      const { result } = renderHook(() => useMixedTranslation('fr-FR'))
-      // The global mock returns key with namespace prefix
-      expect(result.current.t('search', { ns: 'plugin' })).toBe('plugin.search')
-    })
-  })
-})
-
 // ================================
 // useMarketplaceCollectionsAndPlugins Tests
 // ================================
@@ -2088,17 +2045,6 @@
 })

 describe('StickySearchAndSwitchWrapper', () => {
   describe('Props', () => {
-    it('should accept locale prop', () => {
-      render(
-        ,
-      )
-
-      // Component should render without errors
-      expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
-    })
-
     it('should accept showSearchParams prop', () => {
       render(
diff --git a/web/app/components/plugins/marketplace/index.tsx b/web/app/components/plugins/marketplace/index.tsx
index ff9a4d60bc..08d1bc833f 100644
--- a/web/app/components/plugins/marketplace/index.tsx
+++ b/web/app/components/plugins/marketplace/index.tsx
@@ -1,6 +1,5 @@
 import type { MarketplaceCollection, SearchParams } from './types'
 import type { Plugin } from '@/app/components/plugins/types'
'@/app/components/plugins/types' -import type { Locale } from '@/i18n-config' import { TanstackQueryInitializer } from '@/context/query-client' import { MarketplaceContextProvider } from './context' import Description from './description' @@ -9,7 +8,6 @@ import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper' import { getMarketplaceCollectionsAndPlugins } from './utils' type MarketplaceProps = { - locale: Locale showInstallButton?: boolean shouldExclude?: boolean searchParams?: SearchParams @@ -18,7 +16,6 @@ type MarketplaceProps = { showSearchParams?: boolean } const Marketplace = async ({ - locale, showInstallButton = true, shouldExclude, searchParams, @@ -42,14 +39,12 @@ const Marketplace = async ({ scrollContainerId={scrollContainerId} showSearchParams={showSearchParams} > - + { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const { theme } = useTheme() const [isShowInstallFromMarketplace, { setTrue: showInstallFromMarketplace, setFalse: hideInstallFromMarketplace, }] = useBoolean(false) - const localeFromLocale = useLocale() - const { getTagLabel } = useTags(t) + const locale = useLocale() + const { getTagLabel } = useTags() // Memoize marketplace link params to prevent unnecessary re-renders const marketplaceLinkParams = useMemo(() => ({ - language: localeFromLocale, + language: locale, theme, - }), [localeFromLocale, theme]) + }), [locale, theme]) // Memoize tag labels to prevent recreating array on every render const tagLabels = useMemo(() => @@ -52,7 +48,6 @@ const CardWrapperComponent = ({ ({ - useMixedTranslation: (_locale?: string) => ({ +// Mock i18n translation hook +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: (key: string, options?: { ns?: string, num?: number }) => { // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -28,6 +27,7 @@ vi.mock('../hooks', () => ({ return translations[fullKey] || key }, }), + useLocale: () => 'en-US', })) // Mock useMarketplaceContext with controllable values @@ -148,15 +148,15 @@ vi.mock('@/app/components/plugins/install-plugin/install-from-marketplace', () = // Mock SortDropdown component vi.mock('../sort-dropdown', () => ({ - default: ({ locale }: { locale: Locale }) => ( -
<div data-testid='sort-dropdown' data-locale={locale}>Sort</div>
+  default: () => (
+    <div data-testid='sort-dropdown'>Sort</div>
   ),
 }))

 // Mock Empty component
 vi.mock('../empty', () => ({
-  default: ({ className, locale }: { className?: string, locale?: string }) => (
-    <div data-testid='empty-component' className={className} data-locale={locale}>
+  default: ({ className }: { className?: string }) => (
+    <div data-testid='empty-component' className={className}>
     No plugins found
   </div>
), @@ -233,7 +233,6 @@ describe('List', () => { marketplaceCollectionPluginsMap: {} as Record, plugins: undefined, showInstallButton: false, - locale: 'en-US' as Locale, cardContainerClassName: '', cardRender: undefined, onMoreClick: undefined, @@ -351,18 +350,6 @@ describe('List', () => { expect(screen.getByTestId('empty-component')).toHaveClass('custom-empty-class') }) - it('should pass locale to Empty component', () => { - render( - , - ) - - expect(screen.getByTestId('empty-component')).toHaveAttribute('data-locale', 'zh-CN') - }) - it('should pass showInstallButton to CardWrapper', () => { const plugins = createMockPluginList(1) @@ -508,7 +495,6 @@ describe('ListWithCollection', () => { marketplaceCollections: [] as MarketplaceCollection[], marketplaceCollectionPluginsMap: {} as Record, showInstallButton: false, - locale: 'en-US' as Locale, cardContainerClassName: '', cardRender: undefined, onMoreClick: undefined, @@ -820,7 +806,6 @@ describe('ListWrapper', () => { marketplaceCollections: [] as MarketplaceCollection[], marketplaceCollectionPluginsMap: {} as Record, showInstallButton: false, - locale: 'en-US' as Locale, } beforeEach(() => { @@ -901,14 +886,6 @@ describe('ListWrapper', () => { expect(screen.queryByTestId('sort-dropdown')).not.toBeInTheDocument() }) - - it('should pass locale to SortDropdown', () => { - mockContextValues.plugins = createMockPluginList(1) - - render() - - expect(screen.getByTestId('sort-dropdown')).toHaveAttribute('data-locale', 'zh-CN') - }) }) // ================================ @@ -1169,7 +1146,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1188,7 +1164,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1209,7 +1184,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) @@ -1231,7 +1205,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1252,7 +1225,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1274,7 +1246,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1293,7 +1264,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1310,7 +1280,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1327,7 +1296,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1354,7 +1322,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={false} - locale="en-US" />, ) @@ -1375,7 +1342,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={false} - locale="en-US" />, ) @@ -1390,7 
+1356,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1414,7 +1379,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1432,7 +1396,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1450,7 +1413,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1482,7 +1444,6 @@ describe('Combined Workflows', () => { , ) @@ -1501,7 +1462,6 @@ describe('Combined Workflows', () => { , ) @@ -1521,7 +1481,6 @@ describe('Combined Workflows', () => { , ) @@ -1535,7 +1494,6 @@ describe('Combined Workflows', () => { , ) @@ -1551,7 +1509,6 @@ describe('Combined Workflows', () => { , ) @@ -1569,7 +1526,6 @@ describe('Combined Workflows', () => { , ) @@ -1601,7 +1557,6 @@ describe('Accessibility', () => { , ) @@ -1625,7 +1580,6 @@ describe('Accessibility', () => { marketplaceCollections={collections} marketplaceCollectionPluginsMap={pluginsMap} onMoreClick={onMoreClick} - locale="en-US" />, ) @@ -1642,7 +1596,6 @@ describe('Accessibility', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) @@ -1668,7 +1621,6 @@ describe('Performance', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) const endTime = performance.now() @@ -1689,7 +1641,6 @@ describe('Performance', () => { , ) const endTime = performance.now() diff --git a/web/app/components/plugins/marketplace/list/index.tsx b/web/app/components/plugins/marketplace/list/index.tsx index 54889b232f..80b33d0ffd 100644 --- a/web/app/components/plugins/marketplace/list/index.tsx +++ b/web/app/components/plugins/marketplace/list/index.tsx @@ -1,7 +1,6 @@ 'use client' import type { Plugin } from '../../types' import type { MarketplaceCollection } from '../types' -import type { Locale } from '@/i18n-config' import { cn } from '@/utils/classnames' import Empty from '../empty' import CardWrapper from './card-wrapper' @@ -12,7 +11,6 @@ type ListProps = { marketplaceCollectionPluginsMap: Record plugins?: Plugin[] showInstallButton?: boolean - locale: Locale cardContainerClassName?: string cardRender?: (plugin: Plugin) => React.JSX.Element | null onMoreClick?: () => void @@ -23,7 +21,6 @@ const List = ({ marketplaceCollectionPluginsMap, plugins, showInstallButton, - locale, cardContainerClassName, cardRender, onMoreClick, @@ -37,7 +34,6 @@ const List = ({ marketplaceCollections={marketplaceCollections} marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap} showInstallButton={showInstallButton} - locale={locale} cardContainerClassName={cardContainerClassName} cardRender={cardRender} onMoreClick={onMoreClick} @@ -61,7 +57,6 @@ const List = ({ key={`${plugin.org}/${plugin.name}`} plugin={plugin} showInstallButton={showInstallButton} - locale={locale} /> ) }) @@ -71,7 +66,7 @@ const List = ({ } { plugins && !plugins.length && ( - + ) } diff --git a/web/app/components/plugins/marketplace/list/list-with-collection.tsx b/web/app/components/plugins/marketplace/list/list-with-collection.tsx index 8830cc5ddf..c17715e71e 100644 --- 
a/web/app/components/plugins/marketplace/list/list-with-collection.tsx +++ b/web/app/components/plugins/marketplace/list/list-with-collection.tsx @@ -3,9 +3,8 @@ import type { MarketplaceCollection } from '../types' import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types' import type { Plugin } from '@/app/components/plugins/types' -import type { Locale } from '@/i18n-config' +import { useLocale, useTranslation } from '#i18n' import { RiArrowRightSLine } from '@remixicon/react' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { getLanguage } from '@/i18n-config/language' import { cn } from '@/utils/classnames' import CardWrapper from './card-wrapper' @@ -14,7 +13,6 @@ type ListWithCollectionProps = { marketplaceCollections: MarketplaceCollection[] marketplaceCollectionPluginsMap: Record showInstallButton?: boolean - locale: Locale cardContainerClassName?: string cardRender?: (plugin: Plugin) => React.JSX.Element | null onMoreClick?: (searchParams?: SearchParamsFromCollection) => void @@ -23,12 +21,12 @@ const ListWithCollection = ({ marketplaceCollections, marketplaceCollectionPluginsMap, showInstallButton, - locale, cardContainerClassName, cardRender, onMoreClick, }: ListWithCollectionProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() + const locale = useLocale() return ( <> @@ -72,7 +70,6 @@ const ListWithCollection = ({ key={plugin.plugin_id} plugin={plugin} showInstallButton={showInstallButton} - locale={locale} /> ) }) diff --git a/web/app/components/plugins/marketplace/list/list-wrapper.tsx b/web/app/components/plugins/marketplace/list/list-wrapper.tsx index f8126eb34b..84fcf92daf 100644 --- a/web/app/components/plugins/marketplace/list/list-wrapper.tsx +++ b/web/app/components/plugins/marketplace/list/list-wrapper.tsx @@ -1,10 +1,9 @@ 'use client' import type { Plugin } from '../../types' import type { MarketplaceCollection } from '../types' -import type { Locale } from '@/i18n-config' +import { useTranslation } from '#i18n' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { useMarketplaceContext } from '../context' import SortDropdown from '../sort-dropdown' import List from './index' @@ -13,15 +12,13 @@ type ListWrapperProps = { marketplaceCollections: MarketplaceCollection[] marketplaceCollectionPluginsMap: Record showInstallButton?: boolean - locale: Locale } const ListWrapper = ({ marketplaceCollections, marketplaceCollectionPluginsMap, showInstallButton, - locale, }: ListWrapperProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const plugins = useMarketplaceContext(v => v.plugins) const pluginsTotal = useMarketplaceContext(v => v.pluginsTotal) const marketplaceCollectionsFromClient = useMarketplaceContext(v => v.marketplaceCollectionsFromClient) @@ -55,7 +52,7 @@ const ListWrapper = ({
{t('marketplace.pluginsResult', { ns: 'plugin', num: pluginsTotal })}
-        <SortDropdown locale={locale} />
+        <SortDropdown />
) } @@ -73,7 +70,6 @@ const ListWrapper = ({ marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMapFromClient || marketplaceCollectionPluginsMap} plugins={plugins} showInstallButton={showInstallButton} - locale={locale} onMoreClick={handleMoreClick} /> ) diff --git a/web/app/components/plugins/marketplace/plugin-type-switch.tsx b/web/app/components/plugins/marketplace/plugin-type-switch.tsx index 2a89e6847e..b9572413ed 100644 --- a/web/app/components/plugins/marketplace/plugin-type-switch.tsx +++ b/web/app/components/plugins/marketplace/plugin-type-switch.tsx @@ -1,4 +1,5 @@ 'use client' +import { useTranslation } from '#i18n' import { RiArchive2Line, RiBrain2Line, @@ -12,7 +13,6 @@ import { Trigger as TriggerIcon } from '@/app/components/base/icons/src/vender/p import { cn } from '@/utils/classnames' import { PluginCategoryEnum } from '../types' import { useMarketplaceContext } from './context' -import { useMixedTranslation } from './hooks' export const PLUGIN_TYPE_SEARCH_MAP = { all: 'all', @@ -25,16 +25,14 @@ export const PLUGIN_TYPE_SEARCH_MAP = { bundle: 'bundle', } type PluginTypeSwitchProps = { - locale?: string className?: string showSearchParams?: boolean } const PluginTypeSwitch = ({ - locale, className, showSearchParams, }: PluginTypeSwitchProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const activePluginType = useMarketplaceContext(s => s.activePluginType) const handleActivePluginTypeChange = useMarketplaceContext(s => s.handleActivePluginTypeChange) diff --git a/web/app/components/plugins/marketplace/search-box/index.spec.tsx b/web/app/components/plugins/marketplace/search-box/index.spec.tsx index 8c3131f6d1..3e9cc40be0 100644 --- a/web/app/components/plugins/marketplace/search-box/index.spec.tsx +++ b/web/app/components/plugins/marketplace/search-box/index.spec.tsx @@ -10,9 +10,9 @@ import ToolSelectorTrigger from './trigger/tool-selector' // Mock external dependencies only // ================================ -// Mock useMixedTranslation hook -vi.mock('../hooks', () => ({ - useMixedTranslation: (_locale?: string) => ({ +// Mock i18n translation hook +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: (key: string, options?: { ns?: string }) => { // Build full key with namespace prefix if provided const fullKey = options?.ns ? 
`${options.ns}.${key}` : key @@ -364,13 +364,6 @@ describe('SearchBox', () => { expect(container.querySelector('.custom-input-class')).toBeInTheDocument() }) - it('should pass locale to TagsFilter', () => { - render() - - // TagsFilter should be rendered with locale - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() - }) - it('should handle empty placeholder', () => { render() @@ -449,12 +442,6 @@ describe('SearchBoxWrapper', () => { expect(screen.getByRole('textbox')).toBeInTheDocument() }) - it('should render with locale prop', () => { - render() - - expect(screen.getByRole('textbox')).toBeInTheDocument() - }) - it('should render in marketplace mode', () => { const { container } = render() @@ -500,13 +487,6 @@ describe('SearchBoxWrapper', () => { expect(screen.getByPlaceholderText('Search plugins')).toBeInTheDocument() }) - - it('should pass locale to useMixedTranslation', () => { - render() - - // Translation should still work - expect(screen.getByPlaceholderText('Search plugins')).toBeInTheDocument() - }) }) }) @@ -665,12 +645,6 @@ describe('MarketplaceTrigger', () => { }) describe('Props Variations', () => { - it('should handle locale prop', () => { - render() - - expect(screen.getByText('All Tags')).toBeInTheDocument() - }) - it('should handle empty tagsMap', () => { const { container } = render( , @@ -1251,7 +1225,6 @@ describe('Combined Workflows', () => { supportAddCustomTool onShowAddCustomCollectionModal={vi.fn()} placeholder="Search plugins" - locale="en-US" wrapperClassName="custom-wrapper" inputClassName="custom-input" autoFocus={false} diff --git a/web/app/components/plugins/marketplace/search-box/index.tsx b/web/app/components/plugins/marketplace/search-box/index.tsx index 05f98782b9..b6e1f8ee70 100644 --- a/web/app/components/plugins/marketplace/search-box/index.tsx +++ b/web/app/components/plugins/marketplace/search-box/index.tsx @@ -13,7 +13,6 @@ type SearchBoxProps = { tags: string[] onTagsChange: (tags: string[]) => void placeholder?: string - locale?: string supportAddCustomTool?: boolean usedInMarketplace?: boolean onShowAddCustomCollectionModal?: () => void @@ -28,7 +27,6 @@ const SearchBox = ({ tags, onTagsChange, placeholder = '', - locale, usedInMarketplace = false, supportAddCustomTool, onShowAddCustomCollectionModal, @@ -49,7 +47,6 @@ const SearchBox = ({ tags={tags} onTagsChange={onTagsChange} usedInMarketplace - locale={locale} />
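
The specs above now stub the `#i18n` subpath instead of the old `../hooks` module. A minimal sketch of that mock, assuming the component under test only reads `useTranslation` and `useLocale` from `#i18n` (the key-echoing `t` is what lets assertions match on namespaced keys like `plugin.search`):

```ts
import { vi } from 'vitest'

// The factory must cover every export the component consumes from '#i18n'.
vi.mock('#i18n', () => ({
  useTranslation: () => ({
    // Echo the namespaced key so tests can assert on it directly.
    t: (key: string, options?: { ns?: string }) =>
      options?.ns ? `${options.ns}.${key}` : key,
  }),
  useLocale: () => 'en-US',
}))
```
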
@@ -109,7 +106,6 @@ const SearchBox = ({ ) diff --git a/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx b/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx index 1290c26210..d7fc004236 100644 --- a/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx +++ b/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx @@ -1,16 +1,11 @@ 'use client' +import { useTranslation } from '#i18n' import { useMarketplaceContext } from '../context' -import { useMixedTranslation } from '../hooks' import SearchBox from './index' -type SearchBoxWrapperProps = { - locale?: string -} -const SearchBoxWrapper = ({ - locale, -}: SearchBoxWrapperProps) => { - const { t } = useMixedTranslation(locale) +const SearchBoxWrapper = () => { + const { t } = useTranslation() const searchPluginText = useMarketplaceContext(v => v.searchPluginText) const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange) const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags) @@ -24,7 +19,6 @@ const SearchBoxWrapper = ({ onSearchChange={handleSearchPluginTextChange} tags={filterPluginTags} onTagsChange={handleFilterPluginTagsChange} - locale={locale} placeholder={t('searchPlugins', { ns: 'plugin' })} usedInMarketplace /> diff --git a/web/app/components/plugins/marketplace/search-box/tags-filter.tsx b/web/app/components/plugins/marketplace/search-box/tags-filter.tsx index df4d3eebab..9a8035e2e3 100644 --- a/web/app/components/plugins/marketplace/search-box/tags-filter.tsx +++ b/web/app/components/plugins/marketplace/search-box/tags-filter.tsx @@ -1,5 +1,6 @@ 'use client' +import { useTranslation } from '#i18n' import { useState } from 'react' import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' @@ -9,7 +10,6 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useTags } from '@/app/components/plugins/hooks' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import MarketplaceTrigger from './trigger/marketplace' import ToolSelectorTrigger from './trigger/tool-selector' @@ -17,18 +17,16 @@ type TagsFilterProps = { tags: string[] onTagsChange: (tags: string[]) => void usedInMarketplace?: boolean - locale?: string } const TagsFilter = ({ tags, onTagsChange, usedInMarketplace = false, - locale, }: TagsFilterProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const [open, setOpen] = useState(false) const [searchText, setSearchText] = useState('') - const { tags: options, tagsMap } = useTags(t) + const { tags: options, tagsMap } = useTags() const filteredOptions = options.filter(option => option.label.toLowerCase().includes(searchText.toLowerCase())) const handleCheck = (id: string) => { if (tags.includes(id)) @@ -59,7 +57,6 @@ const TagsFilter = ({ open={open} tags={tags} tagsMap={tagsMap} - locale={locale} onTagsChange={onTagsChange} /> ) diff --git a/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx b/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx index 2ba03bd2f2..e387d52d0e 100644 --- a/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx +++ b/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx @@ -1,15 +1,14 @@ import type { Tag } from '../../../hooks' +import { useTranslation } from '#i18n' import { RiArrowDownSLine, RiCloseCircleFill, RiFilter3Line } 
from '@remixicon/react' import * as React from 'react' import { cn } from '@/utils/classnames' -import { useMixedTranslation } from '../../hooks' type MarketplaceTriggerProps = { selectedTagsLength: number open: boolean tags: string[] tagsMap: Record - locale?: string onTagsChange: (tags: string[]) => void } @@ -18,10 +17,9 @@ const MarketplaceTrigger = ({ open, tags, tagsMap, - locale, onTagsChange, }: MarketplaceTriggerProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() return (
{ // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -22,8 +22,8 @@ const mockTranslation = vi.fn((key: string, options?: { ns?: string }) => { return translations[fullKey] || key }) -vi.mock('../hooks', () => ({ - useMixedTranslation: (_locale?: string) => ({ +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: mockTranslation, }), })) @@ -145,36 +145,6 @@ describe('SortDropdown', () => { }) }) - // ================================ - // Props Testing - // ================================ - describe('Props', () => { - it('should accept locale prop', () => { - render() - - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) - - it('should call useMixedTranslation with provided locale', () => { - render() - - // Translation function should be called for labels - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortBy', { ns: 'plugin' }) - }) - - it('should render without locale prop (undefined)', () => { - render() - - expect(screen.getByText('Sort by')).toBeInTheDocument() - }) - - it('should render with empty string locale', () => { - render() - - expect(screen.getByText('Sort by')).toBeInTheDocument() - }) - }) - // ================================ // State Management Tests // ================================ @@ -618,13 +588,6 @@ describe('SortDropdown', () => { expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.newlyReleased', { ns: 'plugin' }) expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.firstReleased', { ns: 'plugin' }) }) - - it('should pass locale to useMixedTranslation', () => { - render() - - // Verify component renders with locale - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) }) // ================================ diff --git a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx index a1f6631735..984b114d03 100644 --- a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx +++ b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx @@ -1,4 +1,5 @@ 'use client' +import { useTranslation } from '#i18n' import { RiArrowDownSLine, RiCheckLine, @@ -9,16 +10,10 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { useMarketplaceContext } from '../context' -type SortDropdownProps = { - locale?: string -} -const SortDropdown = ({ - locale, -}: SortDropdownProps) => { - const { t } = useMixedTranslation(locale) +const SortDropdown = () => { + const { t } = useTranslation() const options = [ { value: 'install_count', diff --git a/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx b/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx index 602a1e9af2..3d3530c83e 100644 --- a/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx +++ b/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx @@ -5,13 +5,11 @@ import PluginTypeSwitch from './plugin-type-switch' import SearchBoxWrapper from './search-box/search-box-wrapper' type StickySearchAndSwitchWrapperProps = { - locale?: string pluginTypeSwitchClassName?: string showSearchParams?: boolean } const StickySearchAndSwitchWrapper = ({ - locale, pluginTypeSwitchClassName, showSearchParams, }: StickySearchAndSwitchWrapperProps) => { @@ -25,9 +23,8 @@ const 
StickySearchAndSwitchWrapper = ({
         pluginTypeSwitchClassName,
       )}
     >
-      <SearchBoxWrapper locale={locale} />
+      <SearchBoxWrapper />
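
This wrapper shows the recurring shape of the refactor: each marketplace component drops its `locale` prop, and `useMixedTranslation(locale)`, which fell back to `i18n.getFixedT`, gives way to context-backed hooks. A sketch of the "after" state, with an illustrative component name:

```tsx
import { useLocale, useTranslation } from '#i18n'

// No locale prop: both hooks read the i18next instance supplied by the
// nearest i18n provider, so server entry points stop threading locale down.
const SearchHint = () => {
  const { t } = useTranslation()
  const locale = useLocale()
  return <span data-locale={locale}>{t('searchPlugins', { ns: 'plugin' })}</span>
}
```
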
diff --git a/web/app/components/plugins/plugin-item/index.tsx b/web/app/components/plugins/plugin-item/index.tsx index d287bd9e9a..3f658c63a8 100644 --- a/web/app/components/plugins/plugin-item/index.tsx +++ b/web/app/components/plugins/plugin-item/index.tsx @@ -44,7 +44,7 @@ const PluginItem: FC = ({ }) => { const { t } = useTranslation() const { theme } = useTheme() - const { categoriesMap } = useCategories(t, true) + const { categoriesMap } = useCategories(true) const currentPluginID = usePluginPageContext(v => v.currentPluginID) const setCurrentPluginID = usePluginPageContext(v => v.setCurrentPluginID) const { refreshPluginList } = useRefreshPluginList() diff --git a/web/app/components/provider/i18n-server.tsx b/web/app/components/provider/i18n-server.tsx new file mode 100644 index 0000000000..23391cf428 --- /dev/null +++ b/web/app/components/provider/i18n-server.tsx @@ -0,0 +1,21 @@ +import { getLocaleOnServer, getResources } from '@/i18n-config/server' + +import { I18nClientProvider } from './i18n' + +export async function I18nServerProvider({ + children, +}: { + children: React.ReactNode +}) { + const locale = await getLocaleOnServer() + const resource = await getResources(locale) + + return ( + + {children} + + ) +} diff --git a/web/app/components/provider/i18n.tsx b/web/app/components/provider/i18n.tsx new file mode 100644 index 0000000000..6441a09dd3 --- /dev/null +++ b/web/app/components/provider/i18n.tsx @@ -0,0 +1,24 @@ +'use client' + +import type { Resource } from 'i18next' +import type { Locale } from '@/i18n-config' +import { I18nextProvider } from 'react-i18next' +import { createI18nextInstance } from '@/i18n-config/client' + +export function I18nClientProvider({ + locale, + resource, + children, +}: { + locale: Locale + resource: Resource + children: React.ReactNode +}) { + const i18n = createI18nextInstance(locale, resource) + + return ( + + {children} + + ) +} diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index b9bb59664a..b793a03ce7 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -32,7 +32,7 @@ import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { AccessMode } from '@/models/access-control' import { fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' import { Resolution, TransferMethod } from '@/types/app' diff --git a/web/app/components/tools/marketplace/index.spec.tsx b/web/app/components/tools/marketplace/index.spec.tsx index 354c717f2d..493d960e2a 100644 --- a/web/app/components/tools/marketplace/index.spec.tsx +++ b/web/app/components/tools/marketplace/index.spec.tsx @@ -19,7 +19,6 @@ vi.mock('@/app/components/plugins/marketplace/list', () => ({ marketplaceCollectionPluginsMap: Record plugins?: unknown[] showInstallButton?: boolean - locale: string }) => { listRenderSpy(props) return
@@ -42,10 +41,6 @@ vi.mock('@/utils/var', () => ({ getMarketplaceUrl: vi.fn(() => 'https://marketplace.test/market'), })) -vi.mock('@/i18n-config', () => ({ - getLocaleOnClient: () => 'en', -})) - vi.mock('next-themes', () => ({ useTheme: () => ({ theme: 'light' }), })) @@ -148,7 +143,6 @@ describe('Marketplace', () => { expect(screen.getByTestId('marketplace-list')).toBeInTheDocument() expect(listRenderSpy).toHaveBeenCalledWith(expect.objectContaining({ showInstallButton: true, - locale: 'en', })) }) }) diff --git a/web/app/components/tools/marketplace/index.tsx b/web/app/components/tools/marketplace/index.tsx index 476a47d8f6..3900a9e505 100644 --- a/web/app/components/tools/marketplace/index.tsx +++ b/web/app/components/tools/marketplace/index.tsx @@ -1,4 +1,5 @@ import type { useMarketplace } from './hooks' +import { useLocale } from '#i18n' import { RiArrowRightUpLine, RiArrowUpDoubleLine, @@ -7,7 +8,6 @@ import { useTheme } from 'next-themes' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import List from '@/app/components/plugins/marketplace/list' -import { getLocaleOnClient } from '@/i18n-config' import { getMarketplaceUrl } from '@/utils/var' type MarketplaceProps = { @@ -24,7 +24,7 @@ const Marketplace = ({ showMarketplacePanel, marketplaceContext, }: MarketplaceProps) => { - const locale = getLocaleOnClient() + const locale = useLocale() const { t } = useTranslation() const { theme } = useTheme() const { @@ -104,7 +104,6 @@ const Marketplace = ({ marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap || {}} plugins={plugins} showInstallButton - locale={locale} /> ) } diff --git a/web/app/forgot-password/ForgotPasswordForm.spec.tsx b/web/app/forgot-password/ForgotPasswordForm.spec.tsx new file mode 100644 index 0000000000..aa360cb6c3 --- /dev/null +++ b/web/app/forgot-password/ForgotPasswordForm.spec.tsx @@ -0,0 +1,163 @@ +import type { InitValidateStatusResponse, SetupStatusResponse } from '@/models/common' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { fetchInitValidateStatus, fetchSetupStatus, sendForgotPasswordEmail } from '@/service/common' +import ForgotPasswordForm from './ForgotPasswordForm' + +const mockPush = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: mockPush }), +})) + +vi.mock('@/service/common', () => ({ + fetchSetupStatus: vi.fn(), + fetchInitValidateStatus: vi.fn(), + sendForgotPasswordEmail: vi.fn(), +})) + +const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) +const mockFetchInitValidateStatus = vi.mocked(fetchInitValidateStatus) +const mockSendForgotPasswordEmail = vi.mocked(sendForgotPasswordEmail) + +const prepareLoadedState = () => { + mockFetchSetupStatus.mockResolvedValue({ step: 'not_started' } as SetupStatusResponse) + mockFetchInitValidateStatus.mockResolvedValue({ status: 'finished' } as InitValidateStatusResponse) +} + +describe('ForgotPasswordForm', () => { + beforeEach(() => { + vi.clearAllMocks() + prepareLoadedState() + }) + + it('should render form after loading', async () => { + render() + + expect(await screen.findByLabelText('login.email')).toBeInTheDocument() + }) + + it('should show validation error when email is empty', async () => { + render() + + await screen.findByLabelText('login.email') + + fireEvent.click(screen.getByRole('button', { name: /login\.sendResetLink/ })) + + await waitFor(() => { + expect(screen.getByText('login.error.emailInValid')).toBeInTheDocument() + }) + 
expect(mockSendForgotPasswordEmail).not.toHaveBeenCalled() + }) + + it('should send reset email and navigate after confirmation', async () => { + mockSendForgotPasswordEmail.mockResolvedValue({ result: 'success', data: 'ok' } as any) + + render() + + const emailInput = await screen.findByLabelText('login.email') + fireEvent.change(emailInput, { target: { value: 'test@example.com' } }) + + fireEvent.click(screen.getByRole('button', { name: /login\.sendResetLink/ })) + + await waitFor(() => { + expect(mockSendForgotPasswordEmail).toHaveBeenCalledWith({ + url: '/forgot-password', + body: { email: 'test@example.com' }, + }) + }) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /login\.backToSignIn/ })).toBeInTheDocument() + }) + + fireEvent.click(screen.getByRole('button', { name: /login\.backToSignIn/ })) + expect(mockPush).toHaveBeenCalledWith('/signin') + }) + + it('should submit when form is submitted', async () => { + mockSendForgotPasswordEmail.mockResolvedValue({ result: 'success', data: 'ok' } as any) + + render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'test@example.com' } }) + + const form = screen.getByRole('button', { name: /login\.sendResetLink/ }).closest('form') + expect(form).not.toBeNull() + + fireEvent.submit(form as HTMLFormElement) + + await waitFor(() => { + expect(mockSendForgotPasswordEmail).toHaveBeenCalledWith({ + url: '/forgot-password', + body: { email: 'test@example.com' }, + }) + }) + }) + + it('should disable submit while request is in flight', async () => { + let resolveRequest: ((value: any) => void) | undefined + const requestPromise = new Promise((resolve) => { + resolveRequest = resolve + }) + mockSendForgotPasswordEmail.mockReturnValue(requestPromise as any) + + render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'test@example.com' } }) + + const button = screen.getByRole('button', { name: /login\.sendResetLink/ }) + fireEvent.click(button) + + await waitFor(() => { + expect(button).toBeDisabled() + }) + + fireEvent.click(button) + expect(mockSendForgotPasswordEmail).toHaveBeenCalledTimes(1) + + resolveRequest?.({ result: 'success', data: 'ok' }) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /login\.backToSignIn/ })).toBeInTheDocument() + }) + }) + + it('should keep form state when request fails', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + mockSendForgotPasswordEmail.mockResolvedValue({ result: 'fail', data: 'error' } as any) + + render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'test@example.com' } }) + fireEvent.click(screen.getByRole('button', { name: /login\.sendResetLink/ })) + + await waitFor(() => { + expect(mockSendForgotPasswordEmail).toHaveBeenCalledTimes(1) + }) + + expect(screen.getByRole('button', { name: /login\.sendResetLink/ })).toBeInTheDocument() + expect(mockPush).not.toHaveBeenCalled() + + consoleSpy.mockRestore() + }) + + it('should redirect to init when status is not started', async () => { + const originalLocation = window.location + Object.defineProperty(window, 'location', { + value: { href: '' }, + writable: true, + }) + mockFetchInitValidateStatus.mockResolvedValue({ status: 'not_started' } as InitValidateStatusResponse) + + render() + + await waitFor(() => { + expect(window.location.href).toBe('/init') + }) + + Object.defineProperty(window, 'location', { + value: originalLocation, + writable: true, + }) + }) 
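
The in-flight test above relies on a manually released promise. A stripped-down sketch of that technique, assuming a vi-mocked request function as in this spec:

```ts
// Capture the resolver so the test controls when the request settles.
let resolveRequest: ((value: unknown) => void) | undefined
const pending = new Promise((resolve) => {
  resolveRequest = resolve
})
mockSendForgotPasswordEmail.mockReturnValue(pending as any)

// While `pending` is open: fire the submit, assert the button is disabled,
// and verify a second click does not issue a second request.

resolveRequest?.({ result: 'success', data: 'ok' })
// After release: await the post-success UI (the back-to-sign-in button).
```
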
+}) diff --git a/web/app/forgot-password/ForgotPasswordForm.tsx b/web/app/forgot-password/ForgotPasswordForm.tsx index 7299d24ebc..ff33cccc82 100644 --- a/web/app/forgot-password/ForgotPasswordForm.tsx +++ b/web/app/forgot-password/ForgotPasswordForm.tsx @@ -1,15 +1,16 @@ 'use client' import type { InitValidateStatusResponse } from '@/models/common' -import { zodResolver } from '@hookform/resolvers/zod' +import { useStore } from '@tanstack/react-form' import { useRouter } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' -import { useForm } from 'react-hook-form' import { useTranslation } from 'react-i18next' import { z } from 'zod' import Button from '@/app/components/base/button' +import { formContext, useAppForm } from '@/app/components/base/form' +import { zodSubmitValidator } from '@/app/components/base/form/utils/zod-submit-validator' import { fetchInitValidateStatus, fetchSetupStatus, @@ -27,44 +28,45 @@ const accountFormSchema = z.object({ .email('error.emailInValid'), }) -type AccountFormValues = z.infer - const ForgotPasswordForm = () => { const { t } = useTranslation() const router = useRouter() const [loading, setLoading] = useState(true) const [isEmailSent, setIsEmailSent] = useState(false) - const { register, trigger, getValues, formState: { errors } } = useForm({ - resolver: zodResolver(accountFormSchema), + + const form = useAppForm({ defaultValues: { email: '' }, + validators: { + onSubmit: zodSubmitValidator(accountFormSchema), + }, + onSubmit: async ({ value }) => { + try { + const res = await sendForgotPasswordEmail({ + url: '/forgot-password', + body: { email: value.email }, + }) + if (res.result === 'success') + setIsEmailSent(true) + else console.error('Email verification failed') + } + catch (error) { + console.error('Request failed:', error) + } + }, }) - const handleSendResetPasswordEmail = async (email: string) => { - try { - const res = await sendForgotPasswordEmail({ - url: '/forgot-password', - body: { email }, - }) - if (res.result === 'success') - setIsEmailSent(true) - - else console.error('Email verification failed') - } - catch (error) { - console.error('Request failed:', error) - } - } + const isSubmitting = useStore(form.store, state => state.isSubmitting) + const emailErrors = useStore(form.store, state => state.fieldMeta.email?.errors) const handleSendResetPasswordClick = async () => { + if (isSubmitting) + return + if (isEmailSent) { router.push('/signin') } else { - const isValid = await trigger('email') - if (isValid) { - const email = getValues('email') - await handleSendResetPasswordEmail(email) - } + form.handleSubmit() } } @@ -94,30 +96,51 @@ const ForgotPasswordForm = () => {
-
- {!isEmailSent && ( -
- -
- - {errors.email && {t(`${errors.email?.message}` as 'error.emailInValid', { ns: 'login' })}} + + { + e.preventDefault() + e.stopPropagation() + form.handleSubmit() + }} + > + {!isEmailSent && ( +
+ +
+ + {field => ( + field.handleChange(e.target.value)} + onBlur={field.handleBlur} + placeholder={t('emailPlaceholder', { ns: 'login' }) || ''} + /> + )} + + {emailErrors && emailErrors.length > 0 && ( + + {t(`${emailErrors[0]}` as 'error.emailInValid', { ns: 'login' })} + + )} +
+ )} +
+
- )} -
- -
- + +
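
Both rewritten auth forms follow the same TanStack Form recipe: validation moves into an `onSubmit` validator, submission logic moves into the form config, and render code subscribes to `form.store` through selectors instead of reading react-hook-form's `formState`. A condensed sketch, assuming `useAppForm` and `zodSubmitValidator` wrap `@tanstack/react-form` as this diff introduces them:

```ts
import { useStore } from '@tanstack/react-form'
import { z } from 'zod'
import { useAppForm } from '@/app/components/base/form'
import { zodSubmitValidator } from '@/app/components/base/form/utils/zod-submit-validator'

const schema = z.object({ email: z.string().email('error.emailInValid') })

export const useEmailForm = (onSuccess: () => void) => {
  const form = useAppForm({
    defaultValues: { email: '' },
    // Validation runs only on submit, mirroring the trigger()-then-submit flow.
    validators: { onSubmit: zodSubmitValidator(schema) },
    onSubmit: async ({ value }) => {
      // send value.email to the backend, then flip UI state
      onSuccess()
    },
  })
  // Selector-based subscriptions replace formState destructuring.
  const isSubmitting = useStore(form.store, s => s.isSubmitting)
  const emailErrors = useStore(form.store, s => s.fieldMeta.email?.errors)
  return { form, isSubmitting, emailErrors }
}
```
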
diff --git a/web/app/install/installForm.spec.tsx b/web/app/install/installForm.spec.tsx new file mode 100644 index 0000000000..74602f916a --- /dev/null +++ b/web/app/install/installForm.spec.tsx @@ -0,0 +1,158 @@ +import type { InitValidateStatusResponse, SetupStatusResponse } from '@/models/common' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { fetchInitValidateStatus, fetchSetupStatus, login, setup } from '@/service/common' +import { encryptPassword } from '@/utils/encryption' +import InstallForm from './installForm' + +const mockPush = vi.fn() +const mockReplace = vi.fn() + +vi.mock('next/navigation', () => ({ + useRouter: () => ({ push: mockPush, replace: mockReplace }), +})) + +vi.mock('@/service/common', () => ({ + fetchSetupStatus: vi.fn(), + fetchInitValidateStatus: vi.fn(), + setup: vi.fn(), + login: vi.fn(), + getSystemFeatures: vi.fn(), +})) + +const mockFetchSetupStatus = vi.mocked(fetchSetupStatus) +const mockFetchInitValidateStatus = vi.mocked(fetchInitValidateStatus) +const mockSetup = vi.mocked(setup) +const mockLogin = vi.mocked(login) + +const prepareLoadedState = () => { + mockFetchSetupStatus.mockResolvedValue({ step: 'not_started' } as SetupStatusResponse) + mockFetchInitValidateStatus.mockResolvedValue({ status: 'finished' } as InitValidateStatusResponse) +} + +describe('InstallForm', () => { + beforeEach(() => { + vi.clearAllMocks() + prepareLoadedState() + }) + + it('should render form after loading', async () => { + render() + + expect(await screen.findByLabelText('login.email')).toBeInTheDocument() + expect(screen.getByRole('button', { name: /login\.installBtn/ })).toBeInTheDocument() + }) + + it('should show validation error when required fields are empty', async () => { + render() + + await screen.findByLabelText('login.email') + + fireEvent.click(screen.getByRole('button', { name: /login\.installBtn/ })) + + await waitFor(() => { + expect(screen.getByText('login.error.emailInValid')).toBeInTheDocument() + expect(screen.getByText('login.error.nameEmpty')).toBeInTheDocument() + }) + expect(mockSetup).not.toHaveBeenCalled() + }) + + it('should submit and redirect to apps on successful login', async () => { + mockSetup.mockResolvedValue({ result: 'success' } as any) + mockLogin.mockResolvedValue({ result: 'success', data: { access_token: 'token' } } as any) + + render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'admin@example.com' } }) + fireEvent.change(screen.getByLabelText('login.name'), { target: { value: 'Admin' } }) + fireEvent.change(screen.getByLabelText('login.password'), { target: { value: 'Password123' } }) + + const form = screen.getByRole('button', { name: /login\.installBtn/ }).closest('form') + expect(form).not.toBeNull() + + fireEvent.submit(form as HTMLFormElement) + + await waitFor(() => { + expect(mockSetup).toHaveBeenCalledWith({ + body: { + email: 'admin@example.com', + name: 'Admin', + password: 'Password123', + language: 'en', + }, + }) + }) + + await waitFor(() => { + expect(mockLogin).toHaveBeenCalledWith({ + url: '/login', + body: { + email: 'admin@example.com', + password: encryptPassword('Password123'), + }, + }) + }) + + await waitFor(() => { + expect(mockReplace).toHaveBeenCalledWith('/apps') + }) + }) + + it('should redirect to sign in when login fails', async () => { + mockSetup.mockResolvedValue({ result: 'success' } as any) + mockLogin.mockResolvedValue({ result: 'fail', data: 'error', code: 'login_failed', message: 'login failed' } as any) + + 
render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'admin@example.com' } }) + fireEvent.change(screen.getByLabelText('login.name'), { target: { value: 'Admin' } }) + fireEvent.change(screen.getByLabelText('login.password'), { target: { value: 'Password123' } }) + + fireEvent.click(screen.getByRole('button', { name: /login\.installBtn/ })) + + await waitFor(() => { + expect(mockReplace).toHaveBeenCalledWith('/signin') + }) + }) + + it('should disable submit while request is in flight', async () => { + let resolveSetup: ((value: any) => void) | undefined + const setupPromise = new Promise((resolve) => { + resolveSetup = resolve + }) + mockSetup.mockReturnValue(setupPromise as any) + mockLogin.mockResolvedValue({ result: 'success', data: { access_token: 'token' } } as any) + + render() + + fireEvent.change(await screen.findByLabelText('login.email'), { target: { value: 'admin@example.com' } }) + fireEvent.change(screen.getByLabelText('login.name'), { target: { value: 'Admin' } }) + fireEvent.change(screen.getByLabelText('login.password'), { target: { value: 'Password123' } }) + + const button = screen.getByRole('button', { name: /login\.installBtn/ }) + fireEvent.click(button) + + await waitFor(() => { + expect(button).toBeDisabled() + }) + + fireEvent.click(button) + expect(mockSetup).toHaveBeenCalledTimes(1) + + resolveSetup?.({ result: 'success' }) + + await waitFor(() => { + expect(mockLogin).toHaveBeenCalledTimes(1) + }) + }) + + it('should redirect to sign in when setup is finished', async () => { + mockFetchSetupStatus.mockResolvedValue({ step: 'finished' } as SetupStatusResponse) + + render() + + await waitFor(() => { + expect(localStorage.setItem).toHaveBeenCalledWith('setup_status', 'finished') + expect(mockPush).toHaveBeenCalledWith('/signin') + }) + }) +}) diff --git a/web/app/install/installForm.tsx b/web/app/install/installForm.tsx index c43fbb4251..de32f18bc7 100644 --- a/web/app/install/installForm.tsx +++ b/web/app/install/installForm.tsx @@ -1,18 +1,17 @@ 'use client' -import type { SubmitHandler } from 'react-hook-form' import type { InitValidateStatusResponse, SetupStatusResponse } from '@/models/common' -import { zodResolver } from '@hookform/resolvers/zod' - -import { useDebounceFn } from 'ahooks' +import { useStore } from '@tanstack/react-form' import Link from 'next/link' import { useRouter } from 'next/navigation' import * as React from 'react' -import { useCallback, useEffect } from 'react' -import { useForm } from 'react-hook-form' +import { useEffect } from 'react' import { useTranslation } from 'react-i18next' import { z } from 'zod' import Button from '@/app/components/base/button' +import { formContext, useAppForm } from '@/app/components/base/form' +import { zodSubmitValidator } from '@/app/components/base/form/utils/zod-submit-validator' +import Input from '@/app/components/base/input' import { validPassword } from '@/config' import { useDocLink } from '@/context/i18n' @@ -33,8 +32,6 @@ const accountFormSchema = z.object({ }).regex(validPassword, 'error.passwordInvalid'), }) -type AccountFormValues = z.infer - const InstallForm = () => { useDocumentTitle('') const { t, i18n } = useTranslation() @@ -42,64 +39,49 @@ const InstallForm = () => { const router = useRouter() const [showPassword, setShowPassword] = React.useState(false) const [loading, setLoading] = React.useState(true) - const { - register, - handleSubmit, - formState: { errors, isSubmitting }, - } = useForm({ - resolver: zodResolver(accountFormSchema), + + 
const form = useAppForm({ defaultValues: { name: '', password: '', email: '', }, - }) + validators: { + onSubmit: zodSubmitValidator(accountFormSchema), + }, + onSubmit: async ({ value }) => { + // First, setup the admin account + await setup({ + body: { + ...value, + language: i18n.language, + }, + }) - const onSubmit: SubmitHandler = async (data) => { - // First, setup the admin account - await setup({ - body: { - ...data, - language: i18n.language, - }, - }) + // Then, automatically login with the same credentials + const loginRes = await login({ + url: '/login', + body: { + email: value.email, + password: encodePassword(value.password), + }, + }) - // Then, automatically login with the same credentials - const loginRes = await login({ - url: '/login', - body: { - email: data.email, - password: encodePassword(data.password), - }, - }) - - // Store tokens and redirect to apps if login successful - if (loginRes.result === 'success') { - router.replace('/apps') - } - else { - // Fallback to signin page if auto-login fails - router.replace('/signin') - } - } - - const handleSetting = async () => { - if (isSubmitting) - return - handleSubmit(onSubmit)() - } - - const { run: debouncedHandleKeyDown } = useDebounceFn( - (e: React.KeyboardEvent) => { - if (e.key === 'Enter') { - e.preventDefault() - handleSetting() + // Store tokens and redirect to apps if login successful + if (loginRes.result === 'success') { + router.replace('/apps') + } + else { + // Fallback to signin page if auto-login fails + router.replace('/signin') } }, - { wait: 200 }, - ) + }) - const handleKeyDown = useCallback(debouncedHandleKeyDown, [debouncedHandleKeyDown]) + const isSubmitting = useStore(form.store, state => state.isSubmitting) + const emailErrors = useStore(form.store, state => state.fieldMeta.email?.errors) + const nameErrors = useStore(form.store, state => state.fieldMeta.name?.errors) + const passwordErrors = useStore(form.store, state => state.fieldMeta.password?.errors) useEffect(() => { fetchSetupStatus().then((res: SetupStatusResponse) => { @@ -128,76 +110,111 @@ const InstallForm = () => {
-
-
- -
- - {errors.email && {t(`${errors.email?.message}` as 'error.emailInValid', { ns: 'login' })}} -
- -
- -
- -
- -
- {errors.name && {t(`${errors.name.message}` as 'error.nameEmpty', { ns: 'login' })}} -
- -
- -
- - -
- + + { + e.preventDefault() + e.stopPropagation() + if (isSubmitting) + return + form.handleSubmit() + }} + > +
+ +
+ + {field => ( + field.handleChange(e.target.value)} + onBlur={field.handleBlur} + placeholder={t('emailPlaceholder', { ns: 'login' }) || ''} + /> + )} + + {emailErrors && emailErrors.length > 0 && ( + + {t(`${emailErrors[0]}` as 'error.emailInValid', { ns: 'login' })} + + )}
-
- {t('error.passwordInvalid', { ns: 'login' })} +
+ +
+ + {field => ( + field.handleChange(e.target.value)} + onBlur={field.handleBlur} + placeholder={t('namePlaceholder', { ns: 'login' }) || ''} + /> + )} + +
+ {nameErrors && nameErrors.length > 0 && ( + + {t(`${nameErrors[0]}` as 'error.nameEmpty', { ns: 'login' })} + + )}
-
-
- -
- +
+ +
+ + {field => ( + field.handleChange(e.target.value)} + onBlur={field.handleBlur} + placeholder={t('passwordPlaceholder', { ns: 'login' }) || ''} + /> + )} + + +
+ +
+
+ +
0, + })} + > + {t('error.passwordInvalid', { ns: 'login' })} +
+
+ +
+ +
+ +
{t('license.tip', { ns: 'login' })} -   +   - - - {children} - - + + + + {children} + + + diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 101ddf559a..01ec4c74ed 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -11,6 +11,7 @@ import Toast from '@/app/components/base/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { login } from '@/service/common' +import { setWebAppAccessToken } from '@/service/webapp-auth' import { encryptPassword } from '@/utils/encryption' import { resolvePostLoginRedirect } from '../utils/post-login-redirect' @@ -65,6 +66,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis }) if (res.result === 'success') { // Track login success event + setWebAppAccessToken(res.data.access_token) trackEvent('user_login_success', { method: 'email_password', is_invite: isInvite, diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index bf4d80326b..be0feea6c1 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -28,7 +28,8 @@ const NormalForm = () => { const message = decodeURIComponent(searchParams.get('message') || '') const invite_token = decodeURIComponent(searchParams.get('invite_token') || '') const [isInitCheckLoading, setInitCheckLoading] = useState(true) - const isLoading = isCheckLoading || loginData?.logged_in || isInitCheckLoading + const [isRedirecting, setIsRedirecting] = useState(false) + const isLoading = isCheckLoading || isInitCheckLoading || isRedirecting const { systemFeatures } = useGlobalPublicStore() const [authType, updateAuthType] = useState<'code' | 'password'>('password') const [showORLine, setShowORLine] = useState(false) @@ -40,6 +41,7 @@ const NormalForm = () => { const init = useCallback(async () => { try { if (isLoggedIn) { + setIsRedirecting(true) const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') return diff --git a/web/context/i18n.ts b/web/context/i18n.ts index e65049b506..d57fc5b984 100644 --- a/web/context/i18n.ts +++ b/web/context/i18n.ts @@ -1,10 +1,10 @@ import type { Locale } from '@/i18n-config/language' -import { atom, useAtomValue } from 'jotai' +import { useTranslation } from '#i18n' import { getDocLanguage, getLanguage, getPricingPageLanguage } from '@/i18n-config/language' -export const localeAtom = atom('en-US') export const useLocale = () => { - return useAtomValue(localeAtom) + const { i18n } = useTranslation() + return i18n.language as Locale } export const useGetLanguage = () => { diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index 2cfe2e5e13..b8191a5eea 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -179,7 +179,7 @@ export default antfu( // 'dify-i18n/no-as-any-in-t': ['error', { mode: 'all' }], 'dify-i18n/no-as-any-in-t': 'error', // 'dify-i18n/no-legacy-namespace-prefix': 'error', - 'dify-i18n/require-ns-option': 'error', + // 'dify-i18n/require-ns-option': 'error', }, }, // i18n JSON validation rules diff --git a/web/i18n-config/client.ts b/web/i18n-config/client.ts new file mode 100644 index 0000000000..17d3dceae1 --- /dev/null +++ b/web/i18n-config/client.ts @@ -0,0 +1,35 @@ +'use client' +import type { Resource } from 'i18next' +import type { Locale } from '.' 
+import type { NamespaceCamelCase, NamespaceKebabCase } from './resources' +import { kebabCase } from 'es-toolkit/string' +import { createInstance } from 'i18next' +import resourcesToBackend from 'i18next-resources-to-backend' +import { getI18n, initReactI18next } from 'react-i18next' +import { getInitOptions } from './settings' + +export function createI18nextInstance(lng: Locale, resources: Resource) { + const instance = createInstance() + instance + .use(initReactI18next) + .use(resourcesToBackend(( + language: Locale, + namespace: NamespaceKebabCase | NamespaceCamelCase, + ) => { + const namespaceKebab = kebabCase(namespace) + return import(`../i18n/${language}/${namespaceKebab}.json`) + })) + .init({ + ...getInitOptions(), + lng, + resources, + }) + return instance +} + +export const changeLanguage = async (lng?: Locale) => { + if (!lng) + return + const i18n = getI18n() + await i18n.changeLanguage(lng) +} diff --git a/web/i18n-config/index.ts b/web/i18n-config/index.ts index bb73ef4b71..e24fd8533d 100644 --- a/web/i18n-config/index.ts +++ b/web/i18n-config/index.ts @@ -2,7 +2,7 @@ import type { Locale } from '@/i18n-config/language' import Cookies from 'js-cookie' import { LOCALE_COOKIE_NAME } from '@/config' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { LanguagesSupported } from '@/i18n-config/language' export const i18n = { @@ -19,10 +19,6 @@ export const setLocaleOnClient = async (locale: Locale, reloadPage = true) => { location.reload() } -export const getLocaleOnClient = (): Locale => { - return Cookies.get(LOCALE_COOKIE_NAME) as Locale || i18n.defaultLocale -} - export const renderI18nObject = (obj: Record, language: string) => { if (!obj) return '' diff --git a/web/i18n-config/lib.client.ts b/web/i18n-config/lib.client.ts new file mode 100644 index 0000000000..fffb4d95ae --- /dev/null +++ b/web/i18n-config/lib.client.ts @@ -0,0 +1,10 @@ +'use client' + +import type { NamespaceCamelCase } from './resources' +import { useTranslation as useTranslationOriginal } from 'react-i18next' + +export function useTranslation(ns?: NamespaceCamelCase) { + return useTranslationOriginal(ns) +} + +export { useLocale } from '@/context/i18n' diff --git a/web/i18n-config/lib.server.ts b/web/i18n-config/lib.server.ts new file mode 100644 index 0000000000..4727ed482f --- /dev/null +++ b/web/i18n-config/lib.server.ts @@ -0,0 +1,16 @@ +import type { NamespaceCamelCase } from './resources' +import { use } from 'react' +import { getLocaleOnServer, getTranslation } from './server' + +async function getI18nConfig(ns?: NamespaceCamelCase) { + const lang = await getLocaleOnServer() + return getTranslation(lang, ns) +} + +export function useTranslation(ns?: NamespaceCamelCase) { + return use(getI18nConfig(ns)) +} + +export function useLocale() { + return use(getLocaleOnServer()) +} diff --git a/web/i18n-config/i18next-config.ts b/web/i18n-config/resources.ts similarity index 59% rename from web/i18n-config/i18next-config.ts rename to web/i18n-config/resources.ts index 0997485967..4bcfb98e14 100644 --- a/web/i18n-config/i18next-config.ts +++ b/web/i18n-config/resources.ts @@ -1,8 +1,4 @@ -'use client' -import type { Locale } from '.' 
-import { camelCase, kebabCase } from 'es-toolkit/string' -import i18n from 'i18next' -import { initReactI18next } from 'react-i18next' +import { kebabCase } from 'es-toolkit/string' import appAnnotation from '../i18n/en-US/app-annotation.json' import appApi from '../i18n/en-US/app-api.json' import appDebug from '../i18n/en-US/app-debug.json' @@ -35,7 +31,7 @@ import tools from '../i18n/en-US/tools.json' import workflow from '../i18n/en-US/workflow.json' // @keep-sorted -export const resources = { +const resources = { app, appAnnotation, appApi, @@ -82,60 +78,5 @@ export type Resources = typeof resources export type NamespaceCamelCase = keyof Resources export type NamespaceKebabCase = KebabCase -const requireSilent = async (lang: Locale, namespace: NamespaceKebabCase) => { - let res - try { - res = (await import(`../i18n/${lang}/${namespace}.json`)).default - } - catch { - res = (await import(`../i18n/en-US/${namespace}.json`)).default - } - - return res -} - -const NAMESPACES = Object.keys(resources).map(kebabCase) as NamespaceKebabCase[] - -// Load a single namespace for a language -export const loadNamespace = async (lang: Locale, ns: NamespaceKebabCase) => { - const camelNs = camelCase(ns) as NamespaceCamelCase - if (i18n.hasResourceBundle(lang, camelNs)) - return - - const resource = await requireSilent(lang, ns) - i18n.addResourceBundle(lang, camelNs, resource, true, true) -} - -// Load all namespaces for a language (used when switching language) -export const loadLangResources = async (lang: Locale) => { - await Promise.all( - NAMESPACES.map(ns => loadNamespace(lang, ns)), - ) -} - -// Initial resources: load en-US namespaces for fallback/default locale -const getInitialTranslations = () => { - return { - 'en-US': resources, - } -} - -if (!i18n.isInitialized) { - i18n.use(initReactI18next).init({ - lng: undefined, - fallbackLng: 'en-US', - resources: getInitialTranslations(), - defaultNS: 'common', - ns: Object.keys(resources), - keySeparator: false, - }) -} - -export const changeLanguage = async (lng?: Locale) => { - if (!lng) - return - await loadLangResources(lng) - await i18n.changeLanguage(lng) -} - -export default i18n +export const namespacesCamelCase = Object.keys(resources) as NamespaceCamelCase[] +export const namespacesKebabCase = namespacesCamelCase.map(ns => kebabCase(ns)) as NamespaceKebabCase[] diff --git a/web/i18n-config/server.ts b/web/i18n-config/server.ts index 4912e86323..403040c134 100644 --- a/web/i18n-config/server.ts +++ b/web/i18n-config/server.ts @@ -1,15 +1,19 @@ -import type { i18n as I18nInstance } from 'i18next' +import type { i18n as I18nInstance, Resource, ResourceLanguage } from 'i18next' import type { Locale } from '.' -import type { NamespaceCamelCase, NamespaceKebabCase } from './i18next-config' +import type { NamespaceCamelCase, NamespaceKebabCase } from './resources' import { match } from '@formatjs/intl-localematcher' -import { camelCase, kebabCase } from 'es-toolkit/compat' +import { kebabCase } from 'es-toolkit/compat' +import { camelCase } from 'es-toolkit/string' import { createInstance } from 'i18next' import resourcesToBackend from 'i18next-resources-to-backend' import Negotiator from 'negotiator' import { cookies, headers } from 'next/headers' +import { cache } from 'react' import { initReactI18next } from 'react-i18next/initReactI18next' -import serverOnlyContext from '@/utils/server-only-context' +import { serverOnlyContext } from '@/utils/server-only-context' import { i18n } from '.' 
+import { namespacesKebabCase } from './resources'
+import { getInitOptions } from './settings'
 
 const [getLocaleCache, setLocaleCache] = serverOnlyContext<Locale | null>(null)
 const [getI18nInstance, setI18nInstance] = serverOnlyContext<I18nInstance | null>(null)
@@ -27,23 +31,21 @@ const getOrCreateI18next = async (lng: Locale) => {
       return import(`../i18n/${language}/${fileNamespace}.json`)
     }))
     .init({
+      ...getInitOptions(),
       lng,
-      fallbackLng: 'en-US',
-      keySeparator: false,
     })
   setI18nInstance(instance)
   return instance
 }
 
-export async function getTranslation(lng: Locale, ns: NamespaceKebabCase) {
-  const camelNs = camelCase(ns) as NamespaceCamelCase
+export async function getTranslation(lng: Locale, ns?: NamespaceCamelCase) {
   const i18nextInstance = await getOrCreateI18next(lng)
-  if (!i18nextInstance.hasLoadedNamespace(camelNs))
-    await i18nextInstance.loadNamespaces(camelNs)
+  if (ns && !i18nextInstance.hasLoadedNamespace(ns))
+    await i18nextInstance.loadNamespaces(ns)
   return {
-    t: i18nextInstance.getFixedT(lng, camelNs),
+    t: i18nextInstance.getFixedT(lng, ns),
     i18n: i18nextInstance,
   }
 }
@@ -77,3 +79,16 @@ export const getLocaleOnServer = async (): Promise<Locale> => {
   setLocaleCache(matchedLocale)
   return matchedLocale
 }
+
+export const getResources = cache(async (lng: Locale): Promise<Resource> => {
+  const messages = {} as ResourceLanguage
+
+  await Promise.all(
+    namespacesKebabCase.map(async (ns) => {
+      const mod = await import(`../i18n/${lng}/${ns}.json`)
+      messages[camelCase(ns)] = mod.default
+    }),
+  )
+
+  return { [lng]: messages }
+})
diff --git a/web/i18n-config/settings.ts b/web/i18n-config/settings.ts
new file mode 100644
index 0000000000..1bf37ab21d
--- /dev/null
+++ b/web/i18n-config/settings.ts
@@ -0,0 +1,13 @@
+import type { InitOptions } from 'i18next'
+import { namespacesCamelCase } from './resources'
+
+export function getInitOptions(): InitOptions {
+  return {
+    // There is no bare "en" locale bundled, so only ever load the exact current locale
+    load: 'currentOnly',
+    fallbackLng: 'en-US',
+    partialBundledLanguages: true,
+    keySeparator: false,
+    ns: namespacesCamelCase,
+  }
+}
diff --git a/web/package.json b/web/package.json
index 05933c08f7..e94f2fd441 100644
--- a/web/package.json
+++ b/web/package.json
@@ -3,7 +3,13 @@
   "type": "module",
   "version": "1.11.2",
   "private": true,
-  "packageManager": "pnpm@10.26.2+sha512.0e308ff2005fc7410366f154f625f6631ab2b16b1d2e70238444dd6ae9d630a8482d92a451144debc492416896ed16f7b114a86ec68b8404b2443869e68ffda6",
+  "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
+  "imports": {
+    "#i18n": {
+      "react-server": "./i18n-config/lib.server.ts",
+      "default": "./i18n-config/lib.client.ts"
+    }
+  },
   "engines": {
     "node": ">=v22.11.0"
   },
@@ -54,7 +60,6 @@
     "@formatjs/intl-localematcher": "^0.5.10",
     "@headlessui/react": "2.2.1",
     "@heroicons/react": "^2.2.0",
-    "@hookform/resolvers": "^5.2.2",
     "@lexical/code": "^0.38.2",
     "@lexical/link": "^0.38.2",
     "@lexical/list": "^0.38.2",
@@ -117,7 +122,6 @@
     "react-18-input-autosize": "^3.0.0",
     "react-dom": "19.2.3",
     "react-easy-crop": "^5.5.3",
-    "react-hook-form": "^7.65.0",
     "react-hotkeys-hook": "^4.6.2",
     "react-i18next": "^16.5.0",
     "react-markdown": "^9.1.0",
diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml
index d9f90b62f3..a2e95dadd7 100644
--- a/web/pnpm-lock.yaml
+++ b/web/pnpm-lock.yaml
@@ -78,9 +78,6 @@ importers:
       '@heroicons/react':
         specifier: ^2.2.0
         version: 2.2.0(react@19.2.3)
-      '@hookform/resolvers':
-        specifier: ^5.2.2
-        version: 5.2.2(react-hook-form@7.68.0(react@19.2.3))
       '@lexical/code':
         specifier: ^0.38.2
         version: 0.38.2
@@ -267,9 +264,6 @@ importers:
       react-easy-crop:
         specifier: ^5.5.3
         version: 5.5.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
-      react-hook-form:
-        specifier: ^7.65.0
-        version: 7.68.0(react@19.2.3)
       react-hotkeys-hook:
         specifier: ^4.6.2
         version: 4.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -1878,11 +1872,6 @@ packages:
     peerDependencies:
       react: '>= 16 || ^19.0.0-rc'
 
-  '@hookform/resolvers@5.2.2':
-    resolution: {integrity: sha512-A/IxlMLShx3KjV/HeTcTfaMxdwy690+L/ZADoeaTltLx+CVuzkeVIPuybK3jrRfw7YZnmdKsVVHAlEPIAEUNlA==}
-    peerDependencies:
-      react-hook-form: ^7.55.0
-
   '@humanfs/core@0.19.1':
     resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==}
    engines: {node: '>=18.18.0'}
@@ -3211,9 +3200,6 @@ packages:
   '@standard-schema/spec@1.1.0':
     resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==}
 
-  '@standard-schema/utils@0.3.0':
-    resolution: {integrity: sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==}
-
   '@storybook/addon-docs@9.1.13':
     resolution: {integrity: sha512-V1nCo7bfC3kQ5VNVq0VDcHsIhQf507m+BxMA5SIYiwdJHljH2BXpW2fL3FFn9gv9Wp57AEEzhm+wh4zANaJgkg==}
     peerDependencies:
@@ -7436,12 +7422,6 @@ packages:
   react-fast-compare@3.2.2:
     resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==}
 
-  react-hook-form@7.68.0:
-    resolution: {integrity: sha512-oNN3fjrZ/Xo40SWlHf1yCjlMK417JxoSJVUXQjGdvdRCU07NTFei1i1f8ApUAts+IVh14e4EdakeLEA+BEAs/Q==}
-    engines: {node: '>=18.0.0'}
-    peerDependencies:
-      react: ^16.8.0 || ^17 || ^18 || ^19
-
   react-hotkeys-hook@4.6.2:
     resolution: {integrity: sha512-FmP+ZriY3EG59Ug/lxNfrObCnW9xQShgk7Nb83+CkpfkcCpfS95ydv+E9JuXA5cp8KtskU7LGlIARpkc92X22Q==}
     peerDependencies:
@@ -10516,11 +10496,6 @@ snapshots:
     dependencies:
      react: 19.2.3
 
-  '@hookform/resolvers@5.2.2(react-hook-form@7.68.0(react@19.2.3))':
-    dependencies:
-      '@standard-schema/utils': 0.3.0
-      react-hook-form: 7.68.0(react@19.2.3)
-
   '@humanfs/core@0.19.1': {}
 
   '@humanfs/node@0.16.7':
@@ -11782,8 +11757,6 @@ snapshots:
   '@standard-schema/spec@1.1.0': {}
 
-  '@standard-schema/utils@0.3.0': {}
-
   '@storybook/addon-docs@9.1.13(@types/react@19.2.7)(storybook@9.1.17(@testing-library/dom@10.4.1)(vite@7.3.0(@types/node@18.15.0)(jiti@1.21.7)(sass@1.95.0)(terser@5.44.1)(tsx@4.21.0)(yaml@2.8.2)))':
     dependencies:
       '@mdx-js/react': 3.1.1(@types/react@19.2.7)(react@19.2.3)
@@ -16931,10 +16904,6 @@ snapshots:
   react-fast-compare@3.2.2: {}
 
-  react-hook-form@7.68.0(react@19.2.3):
-    dependencies:
-      react: 19.2.3
-
   react-hotkeys-hook@4.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
     dependencies:
       react: 19.2.3
diff --git a/web/service/use-apps.ts b/web/service/use-apps.ts
index 0f6c4a64ac..d16d44af20 100644
--- a/web/service/use-apps.ts
+++ b/web/service/use-apps.ts
@@ -12,6 +12,7 @@ import type {
 } from '@/models/app'
 import type { App, AppModeEnum } from '@/types/app'
 import {
+  keepPreviousData,
   useInfiniteQuery,
   useQuery,
   useQueryClient,
@@ -107,6 +108,7 @@ export const useInfiniteAppList = (params: AppListParams, options?: { enabled?: boolean }) => {
     queryFn: ({ pageParam = normalizedParams.page }) => get('/apps', { params: { ...normalizedParams, page: pageParam } }),
     getNextPageParam: lastPage => lastPage.has_more ? lastPage.page + 1 : undefined,
     initialPageParam: normalizedParams.page,
+    placeholderData: keepPreviousData,
     ...options,
   })
 }
diff --git a/web/service/use-common.ts b/web/service/use-common.ts
index a1edb041c0..ca0845d95a 100644
--- a/web/service/use-common.ts
+++ b/web/service/use-common.ts
@@ -221,13 +221,12 @@ export const useIsLogin = () => {
       await get('/account/profile', {}, {
         silent: true,
       })
-    }
-    catch (e: any) {
-      if (e.status === 401)
-        return { logged_in: false }
       return { logged_in: true }
     }
-    return { logged_in: true }
+    catch {
+      // Any error (401, 500, network error, etc.) means not logged in
+      return { logged_in: false }
+    }
   },
 })
 }
diff --git a/web/service/use-dataset-card.ts b/web/service/use-dataset-card.ts
new file mode 100644
index 0000000000..05365479dc
--- /dev/null
+++ b/web/service/use-dataset-card.ts
@@ -0,0 +1,18 @@
+import { useMutation } from '@tanstack/react-query'
+import { checkIsUsedInApp, deleteDataset } from './datasets'
+
+const NAME_SPACE = 'dataset-card'
+
+export const useCheckDatasetUsage = () => {
+  return useMutation({
+    mutationKey: [NAME_SPACE, 'check-usage'],
+    mutationFn: (datasetId: string) => checkIsUsedInApp(datasetId),
+  })
+}
+
+export const useDeleteDataset = () => {
+  return useMutation({
+    mutationKey: [NAME_SPACE, 'delete'],
+    mutationFn: (datasetId: string) => deleteDataset(datasetId),
+  })
+}
diff --git a/web/types/app.ts b/web/types/app.ts
index eb1b29bb60..c8c409a22c 100644
--- a/web/types/app.ts
+++ b/web/types/app.ts
@@ -316,7 +316,7 @@ export type SiteConfig = {
   use_icon_as_answer_icon: boolean
 }
 
-export type AppIconType = 'image' | 'emoji'
+export type AppIconType = 'image' | 'emoji' | 'link'
 
 /**
  * App
diff --git a/web/types/i18n.d.ts b/web/types/i18n.d.ts
index 160d385730..9e20d5a55a 100644
--- a/web/types/i18n.d.ts
+++ b/web/types/i18n.d.ts
@@ -1,4 +1,4 @@
-import type { NamespaceCamelCase, Resources } from '../i18n-config/i18next-config'
+import type { NamespaceCamelCase, Resources } from '../i18n-config/resources'
 import 'i18next'
 
 declare module 'i18next' {
diff --git a/web/utils/server-only-context.ts b/web/utils/server-only-context.ts
index e58dbfe98b..43fe743811 100644
--- a/web/utils/server-only-context.ts
+++ b/web/utils/server-only-context.ts
@@ -2,7 +2,7 @@
 import { cache } from 'react'
 
-export default <T>(defaultValue: T): [() => T, (v: T) => void] => {
+export function serverOnlyContext<T>(defaultValue: T): [() => T, (v: T) => void] {
   const getRef = cache(() => ({ current: defaultValue }))
 
   const getValue = (): T => getRef().current
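// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): how the `#i18n` subpath import
// declared in web/package.json above is meant to be consumed. Node-style
// conditional imports resolve `#i18n` to i18n-config/lib.server.ts when the
// bundler sets the `react-server` condition, and to i18n-config/lib.client.ts
// otherwise, so the same import line works in both Server and Client
// Components. The component name and translation key below are hypothetical;
// the `common` namespace exists in the resources map.
import { useLocale, useTranslation } from '#i18n'

export default function Greeting() {
  // Server build: useTranslation(ns) suspends via use(getI18nConfig(ns)).
  // Client build: it delegates to react-i18next's useTranslation hook.
  const { t } = useTranslation('common')
  const locale = useLocale()
  return <p>{`${t('greeting')} (${locale})`}</p>
}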
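// ---------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): why serverOnlyContext pairs with
// React's cache(). cache() memoizes per server request, so the ref returned by
// getRef() is request-scoped mutable state: a value written through the setter
// is visible to later readers in the same request, but never leaks across
// concurrent requests the way a module-level variable would. The currentUser
// names below are hypothetical; server.ts above applies the same pattern to
// its locale and i18next-instance caches.
import { serverOnlyContext } from '@/utils/server-only-context'

const [getCurrentUser, setCurrentUser] = serverOnlyContext<string | null>(null)

export function rememberUser(name: string) {
  setCurrentUser(name)
}

export function currentUser() {
  // Returns the value set earlier in this request, or null in a fresh request.
  return getCurrentUser()
}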