diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index b4fc44767a..1ac55b5e8d 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from typing import Any
 
 from flask_restx import Resource
 from pydantic import BaseModel, Field
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
     ProviderQuotaExceededError,
 )
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.app_config.entities import ModelConfig
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
@@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
 
 
-class RuleGeneratePayload(BaseModel):
-    instruction: str = Field(..., description="Rule generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-    no_variable: bool = Field(default=False, description="Whether to exclude variables")
-
-
-class RuleCodeGeneratePayload(RuleGeneratePayload):
-    code_language: str = Field(default="javascript", description="Programming language for code generation")
-
-
-class RuleStructuredOutputPayload(BaseModel):
-    instruction: str = Field(..., description="Structured output generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-
-
 class InstructionGeneratePayload(BaseModel):
     flow_id: str = Field(..., description="Workflow/Flow ID")
     node_id: str = Field(default="", description="Node ID for workflow context")
     current: str = Field(default="", description="Current instruction text")
     language: str = Field(default="javascript", description="Programming language (javascript/python)")
     instruction: str = Field(..., description="Instruction for generation")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+    model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
     ideal_output: str = Field(default="", description="Expected ideal output")
 
 
@@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload)
 reg(RuleStructuredOutputPayload)
 reg(InstructionGeneratePayload)
 reg(InstructionTemplatePayload)
+reg(ModelConfig)
 
 
 @console_ns.route("/rule-generate")
@@ -82,12 +69,7 @@ class RuleGenerateApi(Resource):
         _, current_tenant_id = current_account_with_tenant()
 
         try:
-            rules = LLMGenerator.generate_rule_config(
-                tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
-                no_variable=args.no_variable,
-            )
+            rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
@@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource):
         try:
             code_result = LLMGenerator.generate_code(
                 tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
-                code_language=args.code_language,
+                args=args,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource):
         try:
             structured_output = LLMGenerator.generate_structured_output(
                 tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
+                args=args,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource):
             case "llm":
                 return LLMGenerator.generate_rule_config(
                     current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    no_variable=True,
+                    args=RuleGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        no_variable=True,
+                    ),
                 )
             case "agent":
                 return LLMGenerator.generate_rule_config(
                     current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    no_variable=True,
+                    args=RuleGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        no_variable=True,
+                    ),
                 )
             case "code":
                 return LLMGenerator.generate_code(
                     tenant_id=current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    code_language=args.language,
+                    args=RuleCodeGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        code_language=args.language,
+                    ),
                 )
             case _:
                 return {"error": f"invalid node type: {node_type}"}
diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py
new file mode 100644
index 0000000000..3bb8d2c899
--- /dev/null
+++ b/api/core/llm_generator/entities.py
@@ -0,0 +1,20 @@
+"""Shared payload models for LLM generator helpers and controllers."""
+
+from pydantic import BaseModel, Field
+
+from core.app.app_config.entities import ModelConfig
+
+
+class RuleGeneratePayload(BaseModel):
+    instruction: str = Field(..., description="Rule generation instruction")
+    model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
+    no_variable: bool = Field(default=False, description="Whether to exclude variables")
+
+
+class RuleCodeGeneratePayload(RuleGeneratePayload):
+    code_language: str = Field(default="javascript", description="Programming language for code generation")
+
+
+class RuleStructuredOutputPayload(BaseModel):
+    instruction: str = Field(..., description="Structured output generation instruction")
+    model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index be1e306d47..5b2c640265 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -6,6 +6,8 @@ from typing import Protocol, cast
 
 import json_repair
 
+from core.app.app_config.entities import ModelConfig
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.llm_generator.prompts import (
@@ -151,19 +153,19 @@ class LLMGenerator:
         return questions
 
     @classmethod
-    def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
+    def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
         output_parser = RuleConfigGeneratorOutputParser()
 
         error = ""
         error_step = ""
         rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
-        model_parameters = model_config.get("completion_params", {})
-        if no_variable:
+        model_parameters = args.model_config_data.completion_params
+        if args.no_variable:
             prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
 
             prompt_generate = prompt_template.format(
                 inputs={
-                    "TASK_DESCRIPTION": instruction,
+                    "TASK_DESCRIPTION": args.instruction,
                 },
                 remove_template_variables=False,
             )
@@ -175,8 +177,8 @@ class LLMGenerator:
             model_instance = model_manager.get_model_instance(
                 tenant_id=tenant_id,
                 model_type=ModelType.LLM,
-                provider=model_config.get("provider", ""),
-                model=model_config.get("name", ""),
+                provider=args.model_config_data.provider,
+                model=args.model_config_data.name,
             )
 
             try:
@@ -190,7 +192,7 @@ class LLMGenerator:
                 error = str(e)
                 error_step = "generate rule config"
             except Exception as e:
-                logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+                logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
                 rule_config["error"] = str(e)
 
             rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -209,7 +211,7 @@ class LLMGenerator:
         # format the prompt_generate_prompt
         prompt_generate_prompt = prompt_template.format(
             inputs={
-                "TASK_DESCRIPTION": instruction,
+                "TASK_DESCRIPTION": args.instruction,
             },
             remove_template_variables=False,
         )
@@ -220,8 +222,8 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=args.model_config_data.provider,
+            model=args.model_config_data.name,
         )
 
         try:
@@ -250,7 +252,7 @@ class LLMGenerator:
             # the second step to generate the task_parameter and task_statement
             statement_generate_prompt = statement_template.format(
                 inputs={
-                    "TASK_DESCRIPTION": instruction,
+                    "TASK_DESCRIPTION": args.instruction,
                     "INPUT_TEXT": prompt_content.message.get_text_content(),
                 },
                 remove_template_variables=False,
@@ -276,7 +278,7 @@ class LLMGenerator:
             error_step = "generate conversation opener"
 
         except Exception as e:
-            logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+            logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
             rule_config["error"] = str(e)
 
Error: {error}" if error else "" @@ -284,16 +286,20 @@ class LLMGenerator: return rule_config @classmethod - def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): - if code_language == "python": + def generate_code( + cls, + tenant_id: str, + args: RuleCodeGeneratePayload, + ): + if args.code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": instruction, - "CODE_LANGUAGE": code_language, + "INSTRUCTION": args.instruction, + "CODE_LANGUAGE": args.code_language, }, remove_template_variables=False, ) @@ -302,28 +308,28 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = model_config.get("completion_params", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) generated_code = response.message.get_text_content() - return {"code": generated_code, "language": code_language, "error": ""} + return {"code": generated_code, "language": args.code_language, "error": ""} except InvokeError as e: error = str(e) - return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"} except Exception as e: logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language + "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language ) - return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -353,20 +359,20 @@ class LLMGenerator: return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=args.model_config_data.provider, + model=args.model_config_data.name, ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=instruction), + UserPromptMessage(content=args.instruction), ] - model_parameters = model_config.get("model_parameters", {}) + model_parameters = args.model_config_data.completion_params try: response: LLMResult = model_instance.invoke_llm( @@ -390,12 +396,17 @@ class LLMGenerator: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. 
Error: {error}"} except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) + logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} @staticmethod def instruction_modify_legacy( - tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + tenant_id: str, + flow_id: str, + current: str, + instruction: str, + model_config: ModelConfig, + ideal_output: str | None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() @@ -434,7 +445,7 @@ class LLMGenerator: node_id: str, current: str, instruction: str, - model_config: dict, + model_config: ModelConfig, ideal_output: str | None, workflow_service: WorkflowServiceInterface, ): @@ -505,7 +516,7 @@ class LLMGenerator: @staticmethod def __instruction_modify_common( tenant_id: str, - model_config: dict, + model_config: ModelConfig, last_run: dict | None, current: str | None, error_message: str | None, @@ -526,8 +537,8 @@ class LLMGenerator: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider", ""), - model=model_config.get("name", ""), + provider=model_config.provider, + model=model_config.name, ) match node_type: case "llm" | "agent": @@ -570,7 +581,5 @@ class LLMGenerator: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True - ) + logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) return {"error": f"An unexpected error occurred: {str(e)}"}