diff --git a/api/.ruff.toml b/api/.ruff.toml index 3301452ad9..64a461443b 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -106,10 +106,10 @@ ignore = [ "N803", # invalid-argument-name ] "tests/*" = [ - "F811", # redefined-while-unused - "T201", # allow print in tests, - "S110", # allow ignoring exceptions in tests code (currently) - + "F811", # redefined-while-unused + "T201", # allow print in tests, + "S110", # allow ignoring exceptions in tests code (currently) + "PT019", # @patch-injected params look like unused fixtures ] "controllers/console/explore/trial.py" = ["TID251"] "controllers/console/human_input_form.py" = ["TID251"] diff --git a/api/README.md b/api/README.md index 9d89b490b0..9871d2c311 100644 --- a/api/README.md +++ b/api/README.md @@ -122,7 +122,7 @@ These commands assume you start from the repository root. ```bash cd api - uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention + uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry ``` 1. Optional: start Celery Beat (scheduled tasks, in a new terminal). diff --git a/api/app_factory.py b/api/app_factory.py index dcbc821687..11568f139f 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -131,6 +132,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b..831f0a49e0 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index eda6345e14..4920eeba07 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -18,3 +18,44 @@ class EnterpriseFeatureConfig(BaseSettings): description="Allow customization of the enterprise logo.", default=False, ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. + """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + default=True, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ) diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 1ac55b5e8d..280b021bb5 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field @@ -11,12 +12,10 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.wraps import account_initialization_required, setup_required -from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -27,14 +26,32 @@ from services.workflow_service import WorkflowService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") + + class InstructionGeneratePayload(BaseModel): flow_id: str = Field(..., description="Workflow/Flow ID") node_id: str = Field(default="", description="Node ID for workflow context") current: str = Field(default="", description="Current instruction text") language: str = Field(default="javascript", description="Programming language (javascript/python)") instruction: str = Field(..., description="Instruction for generation") - model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") + app_id: str | None = Field(default=None, description="App ID for prompt generation tracing") class InstructionTemplatePayload(BaseModel): @@ -50,7 +67,6 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) -reg(ModelConfig) @console_ns.route("/rule-generate") @@ -66,10 +82,17 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): args = RuleGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: - rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) + rules = LLMGenerator.generate_rule_config( + tenant_id=current_tenant_id, + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=args.no_variable, + user_id=account.id, + app_id=args.app_id, + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -95,12 +118,16 @@ class RuleCodeGenerateApi(Resource): @account_initialization_required def post(self): args = RuleCodeGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: code_result = LLMGenerator.generate_code( tenant_id=current_tenant_id, - args=args, + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.code_language, + user_id=account.id, + app_id=args.app_id, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -127,12 +154,15 @@ class RuleStructuredOutputGenerateApi(Resource): @account_initialization_required def post(self): args = RuleStructuredOutputPayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, - args=args, + instruction=args.instruction, + model_config=args.model_config_data, + user_id=account.id, + app_id=args.app_id, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -159,14 +189,14 @@ class InstructionGenerateApi(Resource): @account_initialization_required def post(self): args = InstructionGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() + account, current_tenant_id = current_account_with_tenant() + app_id = args.app_id or args.flow_id providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] | None = next( (p for p in providers if p.is_accept_language(args.language)), None ) code_template = code_provider.get_default_code() if code_provider else "" try: - # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": app = db.session.query(App).where(App.id == args.flow_id).first() if not app: @@ -183,33 +213,33 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + user_id=account.id, + app_id=app_id, ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=True, + user_id=account.id, + app_id=app_id, ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, - args=RuleCodeGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, - ), + instruction=args.instruction, + model_config=args.model_config_data, + code_language=args.language, + user_id=account.id, + app_id=app_id, ) case _: return {"error": f"invalid node type: {node_type}"} - if args.node_id == "" and args.current != "": # For legacy app without a workflow + if args.node_id == "" and args.current != "": return LLMGenerator.instruction_modify_legacy( tenant_id=current_tenant_id, flow_id=args.flow_id, @@ -217,8 +247,10 @@ class InstructionGenerateApi(Resource): instruction=args.instruction, model_config=args.model_config_data, ideal_output=args.ideal_output, + user_id=account.id, + app_id=app_id, ) - if args.node_id != "" and args.current != "": # For workflow node + if args.node_id != "" and args.current != "": return LLMGenerator.instruction_modify_workflow( tenant_id=current_tenant_id, flow_id=args.flow_id, @@ -228,6 +260,8 @@ class InstructionGenerateApi(Resource): model_config=args.model_config_data, ideal_output=args.ideal_output, workflow_service=WorkflowService(), + user_id=account.id, + app_id=app_id, ) return {"error": "incompatible parameters"}, 400 except ProviderTokenNotInitError as ex: diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index cbcf513162..c5622c7006 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,6 +1,7 @@ from typing import Any from flask import request +from flask_login import current_user from flask_restx import Resource, fields from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest @@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource): try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_id, + tracing_provider=args.tracing_provider, + tracing_config=args.tracing_config, + account_id=current_user.id, ) if not result: raise TracingConfigIsExist() @@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource): try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_id, + tracing_provider=args.tracing_provider, + tracing_config=args.tracing_config, + account_id=current_user.id, ) if not result: raise TracingConfigNotExist() @@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource): args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config( + app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id + ) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 3c6d36afe4..9765d7f41c 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner): self.model_instance = model_instance # init callback - self.agent_callback = DifyAgentCallbackHandler() + self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id) # init dataset tools hit_callback = DatasetIndexToolCallbackHandler( queue_manager=queue_manager, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..d8123593ec 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -63,6 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -564,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle stop events.""" - _ = trace_manager resolved_state = None if self._workflow_run_id: resolved_state = self._resolve_graph_runtime_state(graph_runtime_state) @@ -579,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) with self._database_session() as session: - # Save message - self._save_message(session=session, graph_runtime_state=resolved_state) + self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) yield workflow_finish_resp elif event.stopped_by in ( @@ -589,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ): # When hitting input-moderation or annotation-reply, the workflow will not start with self._database_session() as session: - # Save message - self._save_message(session=session) + self._save_message(session=session, trace_manager=trace_manager) yield self._message_end_to_stream_response() @@ -599,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): event: QueueAdvancedChatMessageEndEvent, *, graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # Save message with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) + self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) yield self._message_end_to_stream_response() @@ -770,7 +770,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if self._conversation_name_generate_thread: logger.debug("Conversation name generation running as daemon thread") - def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): + def _save_message( + self, + *, + session: Session, + graph_runtime_state: GraphRuntimeState | None = None, + trace_manager: TraceQueueManager | None = None, + ): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer @@ -826,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) + if trace_manager: + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MESSAGE_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "conversation_id": str(message.conversation_id), + "message_id": str(message.id), + }, + ), + trace_manager=trace_manager, + ) + def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index ee205ed153..5d04ae56e0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -147,9 +147,12 @@ class WorkflowAppGenerator(BaseAppGenerator): inputs: Mapping[str, Any] = args["inputs"] - extras = { + extras: dict[str, Any] = { **extract_external_trace_id_from_args(args), } + parent_trace_context = args.get("_parent_trace_context") + if parent_trace_context: + extras["parent_trace_context"] = parent_trace_context workflow_run_id = str(uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # trigger shouldn't prepare user inputs diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6c997753fa..2f6f5cc5db 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -52,10 +52,11 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -409,10 +410,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MESSAGE_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "conversation_id": self._conversation_id, + "message_id": self._message_id, + }, + ), + trace_manager=trace_manager, ) message_was_created.send( diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 41052b4f52..aaa8b4e2dc 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -15,8 +15,7 @@ from datetime import datetime from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution from core.workflow.enums import ( @@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._workflow_node_execution_repository.save(domain_execution) self._workflow_node_execution_repository.save_execution_data(domain_execution) + self._enqueue_node_trace_task(domain_execution) def _fail_running_node_executions(self, *, error_message: str) -> None: now = naive_utc_now() @@ -390,17 +390,131 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) external_trace_id = None + parent_trace_context = None if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): external_trace_id = self._application_generate_entity.extras.get("external_trace_id") + parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context") - trace_task = TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=execution, - conversation_id=conversation_id, - user_id=self._trace_manager.user_id, - external_trace_id=external_trace_id, + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName + from core.telemetry import emit as telemetry_emit + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id=self._application_generate_entity.app_config.tenant_id, + user_id=self._trace_manager.user_id, + app_id=self._application_generate_entity.app_config.app_id, + ), + payload={ + "workflow_execution": execution, + "conversation_id": conversation_id, + "user_id": self._trace_manager.user_id, + "external_trace_id": external_trace_id, + "parent_trace_context": parent_trace_context, + }, + ), + trace_manager=self._trace_manager, + ) + + def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None: + if not self._trace_manager: + return + + execution = self._get_workflow_execution() + meta = domain_execution.metadata or {} + + parent_trace_context = None + if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): + parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context") + + node_data: dict[str, Any] = { + "workflow_id": domain_execution.workflow_id, + "workflow_execution_id": execution.id_, + "tenant_id": self._application_generate_entity.app_config.tenant_id, + "app_id": self._application_generate_entity.app_config.app_id, + "node_execution_id": domain_execution.id, + "node_id": domain_execution.node_id, + "node_type": str(domain_execution.node_type.value), + "title": domain_execution.title, + "status": str(domain_execution.status.value), + "error": domain_execution.error, + "elapsed_time": domain_execution.elapsed_time, + "index": domain_execution.index, + "predecessor_node_id": domain_execution.predecessor_node_id, + "created_at": domain_execution.created_at, + "finished_at": domain_execution.finished_at, + "total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS), + "completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS), + "total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None, + "node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None, + "process_data": dict(domain_execution.process_data) if domain_execution.process_data else None, + } + node_data["invoke_from"] = self._application_generate_entity.invoke_from.value + node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value) + + if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs: + results = domain_execution.outputs.get("result") or [] + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + for doc in results: + if not isinstance(doc, dict): + continue + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + dname = doc_meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + if dataset_ids: + node_data["dataset_ids"] = dataset_ids + if dataset_names: + node_data["dataset_names"] = dataset_names + + tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) + if isinstance(tool_info, dict): + plugin_id = tool_info.get("plugin_unique_identifier") + if plugin_id: + node_data["plugin_name"] = plugin_id + credential_id = tool_info.get("credential_id") + if credential_id: + node_data["credential_id"] = credential_id + node_data["credential_provider_type"] = tool_info.get("provider_type") + + conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) + if conversation_id: + node_data["conversation_id"] = conversation_id + + if parent_trace_context: + node_data["parent_trace_context"] = parent_trace_context + + from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName + from core.telemetry import emit as telemetry_emit + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=node_data.get("tenant_id"), + user_id=node_data.get("user_id"), + app_id=node_data.get("app_id"), + ), + payload={"node_execution_data": node_data}, + ), + trace_manager=self._trace_manager, ) - self._trace_manager.add_trace_task(trace_task) def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 6591b08a7e..e1c5f4ac4b 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,8 +4,9 @@ from typing import Any, TextIO, Union from pydantic import BaseModel from configs import dify_config -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.tools.entities.tool_entities import ToolInvokeMessage _TEXT_COLOR_MAPPING = { @@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel): color: str | None = "" current_loop: int = 1 + tenant_id: str | None = None - def __init__(self, color: str | None = None): + def __init__(self, color: str | None = None, tenant_id: str | None = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified self.color = color or "green" self.current_loop = 1 + self.tenant_id = tenant_id def on_tool_start( self, @@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel): print_text("\n") if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.TOOL_TRACE, - message_id=message_id, - tool_name=tool_name, - tool_inputs=tool_inputs, - tool_outputs=tool_outputs, - timer=timer, - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.TOOL_TRACE, + context=TelemetryContext( + tenant_id=self.tenant_id, + app_id=trace_manager.app_id, + user_id=trace_manager.user_id, + ), + payload={ + "message_id": message_id, + "tool_name": tool_name, + "tool_inputs": tool_inputs, + "tool_outputs": tool_outputs, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 5b2c640265..4279d44fc0 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -6,8 +6,6 @@ from typing import Protocol, cast import json_repair -from core.app.app_config.entities import ModelConfig -from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import ( @@ -27,10 +25,10 @@ from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from extensions.ext_storage import storage @@ -73,8 +71,8 @@ class LLMGenerator: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ) - answer = response.message.get_text_content() - if answer == "": + answer = cast(str, response.message.content) + if answer is None: return "" try: result_dict = json.loads(answer) @@ -96,15 +94,17 @@ class LLMGenerator: name = name[:75] + "..." # get tracing instance - trace_manager = TraceQueueManager(app_id=app_id) - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.GENERATE_NAME_TRACE, - conversation_id=conversation_id, - generate_conversation_name=name, - inputs=prompt, - timer=timer, - tenant_id=tenant_id, + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.GENERATE_NAME_TRACE, + context=TelemetryContext(tenant_id=tenant_id, app_id=app_id), + payload={ + "conversation_id": conversation_id, + "generate_conversation_name": name, + "inputs": prompt, + "timer": timer, + "tenant_id": tenant_id, + }, ) ) @@ -153,19 +153,27 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload): + def generate_rule_config( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + no_variable: bool, + user_id: str | None = None, + app_id: str | None = None, + ): output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = args.model_config_data.completion_params - if args.no_variable: + model_parameters = model_config.get("completion_params", {}) + if no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ - "TASK_DESCRIPTION": args.instruction, + "TASK_DESCRIPTION": instruction, }, remove_template_variables=False, ) @@ -177,26 +185,45 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False - ) + llm_result = None + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) - rule_config["prompt"] = response.message.get_text_content() + rule_config["prompt"] = cast(str, llm_result.message.content) - except InvokeError as e: - error = str(e) - error_step = "generate rule config" - except Exception as e: - logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + except InvokeError as e: + error = str(e) + error_step = "generate rule config" + except Exception as e: + logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + rule_config["error"] = str(e) + error = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + if user_id: + prompt_value = rule_config.get("prompt", "") + generated_output = str(prompt_value) if prompt_value else "" + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="rule_generate", + instruction=instruction, + generated_output=generated_output, + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error or None, + ) + return rule_config # get rule config prompt, parameter and statement @@ -211,7 +238,7 @@ class LLMGenerator: # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ - "TASK_DESCRIPTION": args.instruction, + "TASK_DESCRIPTION": instruction, }, remove_template_variables=False, ) @@ -222,84 +249,125 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) - try: + llm_result = None + with measure_time() as timer: try: - # the first step to generate the task prompt - prompt_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + try: + # the first step to generate the task prompt + prompt_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + llm_result = prompt_content + except InvokeError as e: + error = str(e) + error_step = "generate prefix prompt" + rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="rule_generate", + instruction=instruction, + generated_output="", + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, + ) + + return rule_config + + rule_config["prompt"] = cast(str, prompt_content.message.content) + + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") + parameter_generate_prompt = parameter_template.format( + inputs={ + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, ) - except InvokeError as e: - error = str(e) - error_step = "generate prefix prompt" - rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] - return rule_config - - rule_config["prompt"] = prompt_content.message.get_text_content() - - parameter_generate_prompt = parameter_template.format( - inputs={ - "INPUT_TEXT": prompt_content.message.get_text_content(), - }, - remove_template_variables=False, - ) - parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] - - # the second step to generate the task_parameter and task_statement - statement_generate_prompt = statement_template.format( - inputs={ - "TASK_DESCRIPTION": args.instruction, - "INPUT_TEXT": prompt_content.message.get_text_content(), - }, - remove_template_variables=False, - ) - statement_messages = [UserPromptMessage(content=statement_generate_prompt)] - - try: - parameter_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False + # the second step to generate the task_parameter and task_statement + statement_generate_prompt = statement_template.format( + inputs={ + "TASK_DESCRIPTION": instruction, + "INPUT_TEXT": prompt_content.message.content, + }, + remove_template_variables=False, ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content()) - except InvokeError as e: - error = str(e) - error_step = "generate variables" + statement_messages = [UserPromptMessage(content=statement_generate_prompt)] - try: - statement_content: LLMResult = model_instance.invoke_llm( - prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False - ) - rule_config["opening_statement"] = statement_content.message.get_text_content() - except InvokeError as e: - error = str(e) - error_step = "generate conversation opener" + try: + parameter_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False + ) + rule_config["variables"] = re.findall( + r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content) + ) + except InvokeError as e: + error = str(e) + error_step = "generate variables" - except Exception as e: - logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + try: + statement_content: LLMResult = model_instance.invoke_llm( + prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False + ) + rule_config["opening_statement"] = cast(str, statement_content.message.content) + except InvokeError as e: + error = str(e) + error_step = "generate conversation opener" + + except Exception as e: + logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) + rule_config["error"] = str(e) + error = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" + if user_id: + generated_output = rule_config.get("prompt", "") + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="rule_generate", + instruction=instruction, + generated_output=str(generated_output) if generated_output else "", + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error or None, + ) + return rule_config @classmethod def generate_code( cls, tenant_id: str, - args: RuleCodeGeneratePayload, + instruction: str, + model_config: dict, + code_language: str = "javascript", + user_id: str | None = None, + app_id: str | None = None, ): - if args.code_language == "python": + if code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) prompt = prompt_template.format( inputs={ - "INSTRUCTION": args.instruction, - "CODE_LANGUAGE": args.code_language, + "INSTRUCTION": instruction, + "CODE_LANGUAGE": code_language, }, remove_template_variables=False, ) @@ -308,28 +376,49 @@ class LLMGenerator: model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = args.model_config_data.completion_params - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + model_parameters = model_config.get("completion_params", {}) + + llm_result = None + error = None + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + generated_code = cast(str, llm_result.message.content) + result = {"code": generated_code, "language": code_language, "error": ""} + + except InvokeError as e: + error = str(e) + result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logger.exception( + "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language + ) + error = str(e) + result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="code_generate", + instruction=instruction, + generated_output=result.get("code", ""), + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - generated_code = response.message.get_text_content() - return {"code": generated_code, "language": args.code_language, "error": ""} - - except InvokeError as e: - error = str(e) - return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"} - except Exception as e: - logger.exception( - "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language - ) - return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"} + return result @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): @@ -355,49 +444,76 @@ class LLMGenerator: raise TypeError("Expected LLMResult when stream=False") response = result - answer = response.message.get_text_content() + answer = cast(str, response.message.content) return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): + def generate_structured_output( + cls, tenant_id: str, instruction: str, model_config: dict, user_id: str | None = None, app_id: str | None = None + ): model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=args.model_config_data.provider, - model=args.model_config_data.name, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) prompt_messages = [ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), - UserPromptMessage(content=args.instruction), + UserPromptMessage(content=instruction), ] - model_parameters = args.model_config_data.completion_params + model_parameters = model_config.get("model_parameters", {}) - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + llm_result = None + error = None + result = {"output": "", "error": ""} + + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + raw_content = llm_result.message.content + + if not isinstance(raw_content, str): + raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}") + + try: + parsed_content = json.loads(raw_content) + except json.JSONDecodeError: + parsed_content = json_repair.loads(raw_content) + + if not isinstance(parsed_content, dict | list): + raise ValueError(f"Failed to parse structured output from llm: {raw_content}") + + generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False) + result = {"output": generated_json_schema, "error": ""} + + except InvokeError as e: + error = str(e) + result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} + except Exception as e: + logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) + error = str(e) + result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + cls._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="structured_output", + instruction=instruction, + generated_output=result.get("output", ""), + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - raw_content = response.message.get_text_content() - - try: - parsed_content = json.loads(raw_content) - except json.JSONDecodeError: - parsed_content = json_repair.loads(raw_content) - - if not isinstance(parsed_content, dict | list): - raise ValueError(f"Failed to parse structured output from llm: {raw_content}") - - generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False) - return {"output": generated_json_schema, "error": ""} - - except InvokeError as e: - error = str(e) - return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} - except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name) - return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + return result @staticmethod def instruction_modify_legacy( @@ -405,14 +521,16 @@ class LLMGenerator: flow_id: str, current: str, instruction: str, - model_config: ModelConfig, + model_config: dict, ideal_output: str | None, + user_id: str | None = None, + app_id: str | None = None, ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() ) if not last_run: - return LLMGenerator.__instruction_modify_common( + result = LLMGenerator.__instruction_modify_common( tenant_id=tenant_id, model_config=model_config, last_run=None, @@ -421,22 +539,28 @@ class LLMGenerator: instruction=instruction, node_type="llm", ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) - last_run_dict = { - "query": last_run.query, - "answer": last_run.answer, - "error": last_run.error, - } - return LLMGenerator.__instruction_modify_common( - tenant_id=tenant_id, - model_config=model_config, - last_run=last_run_dict, - current=current, - error_message=str(last_run.error), - instruction=instruction, - node_type="llm", - ideal_output=ideal_output, - ) + else: + last_run_dict = { + "query": last_run.query, + "answer": last_run.answer, + "error": last_run.error, + } + result = LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=str(last_run.error), + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, + ) + return result @staticmethod def instruction_modify_workflow( @@ -445,9 +569,11 @@ class LLMGenerator: node_id: str, current: str, instruction: str, - model_config: ModelConfig, + model_config: dict, ideal_output: str | None, workflow_service: WorkflowServiceInterface, + user_id: str | None = None, + app_id: str | None = None, ): session = db.session() @@ -478,6 +604,8 @@ class LLMGenerator: instruction=instruction, node_type=node_type, ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: @@ -511,18 +639,22 @@ class LLMGenerator: instruction=instruction, node_type=last_run.node_type, ideal_output=ideal_output, + user_id=user_id, + app_id=app_id, ) @staticmethod def __instruction_modify_common( tenant_id: str, - model_config: ModelConfig, + model_config: dict, last_run: dict | None, current: str | None, error_message: str | None, instruction: str, node_type: str, ideal_output: str | None, + user_id: str | None = None, + app_id: str | None = None, ): LAST_RUN = "{{#last_run#}}" CURRENT = "{{#current#}}" @@ -537,8 +669,8 @@ class LLMGenerator: model_instance = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.name, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) match node_type: case "llm" | "agent": @@ -562,24 +694,122 @@ class LLMGenerator: ] model_parameters = {"temperature": 0.4} - try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + llm_result = None + error = None + result = {} + + with measure_time() as timer: + try: + llm_result = model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ) + + generated_raw = llm_result.message.get_text_content() + first_brace = generated_raw.find("{") + last_brace = generated_raw.rfind("}") + if first_brace == -1 or last_brace == -1 or last_brace < first_brace: + raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") + json_str = generated_raw[first_brace : last_brace + 1] + data = json_repair.loads(json_str) + if not isinstance(data, dict): + raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") + result = data + except InvokeError as e: + error = str(e) + result = {"error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logger.exception( + "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True + ) + error = str(e) + result = {"error": f"An unexpected error occurred: {str(e)}"} + + if user_id: + generated_output = "" + if isinstance(result, dict): + for key in ["prompt", "code", "output", "modified"]: + if result.get(key): + generated_output = str(result[key]) + break + + LLMGenerator._emit_prompt_generation_trace( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type="instruction_modify", + instruction=instruction, + generated_output=generated_output, + llm_result=llm_result, + model_config=model_config, + timer=timer, + error=error, ) - generated_raw = response.message.get_text_content() - first_brace = generated_raw.find("{") - last_brace = generated_raw.rfind("}") - if first_brace == -1 or last_brace == -1 or last_brace < first_brace: - raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") - json_str = generated_raw[first_brace : last_brace + 1] - data = json_repair.loads(json_str) - if not isinstance(data, dict): - raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") - return data - except InvokeError as e: - error = str(e) - return {"error": f"Failed to generate code. Error: {error}"} - except Exception as e: - logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True) - return {"error": f"An unexpected error occurred: {str(e)}"} + return result + + @classmethod + def _emit_prompt_generation_trace( + cls, + tenant_id: str, + user_id: str, + app_id: str | None, + operation_type: str, + instruction: str, + generated_output: str, + llm_result: LLMResult | None, + model_config: dict | None = None, + timer=None, + error: str | None = None, + ): + if llm_result: + prompt_tokens = llm_result.usage.prompt_tokens + completion_tokens = llm_result.usage.completion_tokens + total_tokens = llm_result.usage.total_tokens + model_name = llm_result.model + # Extract provider from model_config if available, otherwise fall back to parsing model name + if model_config and model_config.get("provider"): + model_provider = model_config.get("provider", "") + else: + model_provider = model_name.split("/")[0] if "/" in model_name else "" + latency = llm_result.usage.latency + total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None + currency = llm_result.usage.currency + else: + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + model_provider = model_config.get("provider", "") if model_config else "" + model_name = model_config.get("name", "") if model_config else "" + latency = 0.0 + if timer: + start_time = timer.get("start") + end_time = timer.get("end") + if start_time and end_time: + latency = (end_time - start_time).total_seconds() + total_price = None + currency = None + + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.PROMPT_GENERATION_TRACE, + context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id), + payload={ + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id, + "operation_type": operation_type, + "instruction": instruction, + "generated_output": generated_output, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "model_provider": model_provider, + "model_name": model_name, + "latency": latency, + "total_price": total_price, + "currency": currency, + "timer": timer, + "error": error, + }, + ) + ) diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index 1e8aa8d566..bc816eb66b 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter): """ def filter(self, record: logging.LogRecord) -> bool: - # Get trace context from OpenTelemetry - trace_id, span_id = self._get_otel_context() + # Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event) + existing_trace_id = getattr(record, "trace_id", "") + if not existing_trace_id: + # Get trace context from OpenTelemetry + trace_id, span_id = self._get_otel_context() - # Set trace_id (fallback to ContextVar if no OTEL context) - if trace_id: - record.trace_id = trace_id + # Set trace_id (fallback to ContextVar if no OTEL context) + if trace_id: + record.trace_id = trace_id + else: + record.trace_id = get_trace_id() + + record.span_id = span_id or "" else: - record.trace_id = get_trace_id() - - record.span_id = span_id or "" + # Keep existing trace_id; only fill span_id if missing + if not getattr(record, "span_id", ""): + record.span_id = "" # For backward compatibility, also set req_id record.req_id = get_request_id() @@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: identity = self._extract_identity() - record.tenant_id = identity.get("tenant_id", "") - record.user_id = identity.get("user_id", "") - record.user_type = identity.get("user_type", "") + if not getattr(record, "tenant_id", ""): + record.tenant_id = identity.get("tenant_id", "") + if not getattr(record, "user_id", ""): + record.user_id = identity.get("user_id", "") + if not getattr(record, "user_type", ""): + record.user_type = identity.get("user_type", "") return True def _extract_identity(self) -> dict[str, str]: diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 21dc58f16f..4afe706a62 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -5,9 +5,10 @@ from typing import Any from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.ops.utils import measure_time +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit logger = logging.getLogger(__name__) @@ -49,14 +50,18 @@ class InputModeration: moderation_result = moderation_factory.moderation_for_inputs(inputs, query) if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MODERATION_TRACE, - message_id=message_id, - moderation_result=moderation_result, - inputs=inputs, - timer=timer, - ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.MODERATION_TRACE, + context=TelemetryContext(tenant_id=tenant_id, app_id=app_id), + payload={ + "message_id": message_id, + "moderation_result": moderation_result, + "inputs": inputs, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) if not moderation_result.flagged: diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..5c878281a6 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -48,10 +48,14 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] query: str metadata: dict[str, Any] + invoked_by: str | None = None + class MessageTraceInfo(BaseTraceInfo): conversation_model: str @@ -59,7 +63,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +110,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +118,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,16 +205,22 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" WORKFLOW_TRACE = "workflow" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" SUGGESTED_QUESTION_TRACE = "suggested_question" DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" DATASOURCE_TRACE = "datasource" + NODE_EXECUTION_TRACE = "node_execution" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4de4f403ce..422a121311 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -3,6 +3,7 @@ import os from datetime import datetime, timedelta from langfuse import Langfuse +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance @@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory from core.workflow.enums import NodeType from extensions.ext_database import db -from models import EndUser, WorkflowNodeExecutionTriggeredFrom +from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus logger = logging.getLogger(__name__) @@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance): metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id - if trace_info.message_id: + # Check for parent_trace_context to detect nested workflow + parent_trace_context = trace_info.metadata.get("parent_trace_context") + + if parent_trace_context: + # Nested workflow: create span under outer trace + outer_trace_id = parent_trace_context.get("trace_id") + parent_node_execution_id = parent_trace_context.get("parent_node_execution_id") + parent_conversation_id = parent_trace_context.get("parent_conversation_id") + parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id") + + # Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id + if parent_conversation_id: + session_factory = sessionmaker(bind=db.engine) + with session_factory() as session: + message_data_stmt = select(Message.id).where( + Message.conversation_id == parent_conversation_id, + Message.workflow_run_id == parent_workflow_run_id, + ) + resolved_message_id = session.scalar(message_data_stmt) + if resolved_message_id: + outer_trace_id = resolved_message_id + else: + outer_trace_id = parent_workflow_run_id + else: + outer_trace_id = parent_workflow_run_id + + # Create inner workflow span under outer trace + workflow_span_data = LangfuseSpan( + id=trace_info.workflow_run_id, + name=TraceTaskName.WORKFLOW_TRACE, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), + trace_id=outer_trace_id, + parent_observation_id=parent_node_execution_id, + start_time=trace_info.start_time, + end_time=trace_info.end_time, + metadata=metadata, + level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, + status_message=trace_info.error or "", + ) + self.add_span(langfuse_span_data=workflow_span_data) + # Use outer_trace_id for all node spans/generations + trace_id = outer_trace_id + elif trace_info.message_id: trace_id = trace_info.trace_id or trace_info.message_id name = TraceTaskName.MESSAGE_TRACE trace_data = LangfuseTrace( @@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance): } ) + # Determine parent_observation_id for nested workflows + node_parent_observation_id = None + if parent_trace_context or trace_info.message_id: + node_parent_observation_id = trace_info.workflow_run_id + # add generation span if process_data and process_data.get("model_mode") == "chat": total_token = metadata.get("total_tokens", 0) @@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance): metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error or "", - parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None, + parent_observation_id=node_parent_observation_id, usage=generation_usage, ) @@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance): metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), status_message=trace_info.error or "", - parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None, + parent_observation_id=node_parent_observation_id, ) self.add_span(langfuse_span_data=span_data) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 8b8117b24c..7ca51e10ef 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -6,6 +6,7 @@ from typing import cast from langsmith import Client from langsmith.schemas import RunBase +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance @@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db -from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom +from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id + # Check for parent_trace_context for cross-workflow linking + parent_trace_context = trace_info.metadata.get("parent_trace_context") + + if parent_trace_context: + # Inner workflow: resolve outer trace_id and link to parent node + outer_trace_id = parent_trace_context.get("parent_workflow_run_id") + + # Try to resolve message_id from conversation_id if available + if parent_trace_context.get("parent_conversation_id"): + try: + session_factory = sessionmaker(bind=db.engine) + with session_factory() as session: + message_data_stmt = select(Message.id).where( + Message.conversation_id == parent_trace_context["parent_conversation_id"], + Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"], + ) + resolved_message_id = session.scalar(message_data_stmt) + if resolved_message_id: + outer_trace_id = resolved_message_id + except Exception as e: + logger.debug("Failed to resolve message_id from conversation_id: %s", str(e)) + + trace_id = outer_trace_id + parent_run_id = parent_trace_context.get("parent_node_execution_id") + else: + # Outer workflow: existing behavior + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id + parent_run_id = trace_info.message_id or None + if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance): metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id - if trace_info.message_id: + # Only create message_run for outer workflows (no parent_trace_context) + if trace_info.message_id and not parent_trace_context: message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE, @@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance): }, error=trace_info.error, tags=["workflow"], - parent_run_id=trace_info.message_id or None, + parent_run_id=parent_run_id, trace_id=trace_id, - dotted_order=workflow_dotted_order, + dotted_order=None if parent_trace_context else workflow_dotted_order, serialized=None, events=[], session_id=None, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 84f5bf5512..3f7bc662fe 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -21,19 +21,25 @@ from core.ops.entities.config_entity import ( ) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -43,6 +49,44 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) + return str(name) if name else "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): def __getitem__(self, provider: str) -> dict[str, Any]: match provider: @@ -317,6 +361,10 @@ class OpsTraceManager: if app_id is None: return None + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: @@ -479,6 +527,56 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. + """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + + @classmethod + def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]: + from core.workflow.enums import WorkflowNodeExecutionMetadataKey + from models.workflow import WorkflowNodeExecutionModel + + with Session(db.engine) as session: + node_executions = session.scalars( + select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ).all() + + total_prompt = 0 + total_completion = 0 + + for node_exec in node_executions: + metadata = node_exec.execution_metadata_dict + + prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS) + if prompt is not None: + total_prompt += prompt + + completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS) + if completion is not None: + total_completion += completion + + return (total_prompt, total_completion) + def __init__( self, trace_type: Any, @@ -499,6 +597,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -512,7 +612,7 @@ class TraceTask: TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -528,6 +628,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -563,6 +666,10 @@ class TraceTask: total_tokens = workflow_run.total_tokens + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) + file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -583,7 +690,9 @@ class TraceTask: ) message_id = session.scalar(message_data_stmt) - metadata = { + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -596,8 +705,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -612,6 +727,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -619,10 +736,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -645,6 +763,14 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -656,7 +782,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -698,6 +831,8 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -739,6 +874,8 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None @@ -778,6 +915,36 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + metadata = { "message_id": message_id, "ls_provider": message_data.model_provider, @@ -788,13 +955,21 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -837,6 +1012,10 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() @@ -891,6 +1070,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -905,6 +1086,158 @@ class TraceTask: return generate_name_trace_info + def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + app_name, workspace_name = _lookup_app_and_workspace_names(node_data.get("app_id"), node_data.get("tenant_id")) + + credential_name = _lookup_credential_name( + node_data.get("credential_id"), node_data.get("credential_provider_type") + ) + + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + "credential_name": credential_name, + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -938,13 +1271,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -980,20 +1317,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 541c241ae5..33884378ce 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -27,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -56,6 +55,8 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db @@ -728,10 +729,21 @@ class DatasetRetrieval: self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer - ) + app_config = self.application_generate_entity.app_config if self.application_generate_entity else None + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, + context=TelemetryContext( + tenant_id=app_config.tenant_id if app_config else None, + app_id=app_config.app_id if app_config else None, + ), + payload={ + "message_id": message_id, + "documents": documents, + "timer": timer, + }, + ), + trace_manager=trace_manager, ) def _on_query( diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py new file mode 100644 index 0000000000..b1d25403a0 --- /dev/null +++ b/api/core/telemetry/__init__.py @@ -0,0 +1,60 @@ +"""Community telemetry helpers. + +Provides ``emit()`` which enqueues trace events into the CE trace pipeline +(``TraceQueueManager`` → ``ops_trace`` Celery queue → Langfuse / LangSmith / etc.). + +Enterprise-only traces (node execution, draft node execution, prompt generation) +are silently dropped when enterprise telemetry is disabled. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + +_ENTERPRISE_ONLY_TRACES: frozenset[TraceTaskName] = frozenset( + { + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + } +) + + +def _is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + if event.name in _ENTERPRISE_ONLY_TRACES and not _is_enterprise_telemetry_enabled(): + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=event.context.app_id, + user_id=event.context.user_id, + ) + queue_manager.add_trace_task(TraceTask(event.name, **event.payload)) + + +is_enterprise_telemetry_enabled = _is_enterprise_telemetry_enabled + +__all__ = [ + "TelemetryContext", + "TelemetryEvent", + "TraceTaskName", + "emit", + "is_enterprise_telemetry_enabled", +] diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py new file mode 100644 index 0000000000..35ace47510 --- /dev/null +++ b/api/core/telemetry/events.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.ops.entities.trace_entity import TraceTaskName + + +@dataclass(frozen=True) +class TelemetryContext: + tenant_id: str | None = None + user_id: str | None = None + app_id: str | None = None + + +@dataclass(frozen=True) +class TelemetryEvent: + name: TraceTaskName + context: TelemetryContext + payload: dict[str, Any] diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9c1ceff145..0106f60c0d 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -50,6 +50,7 @@ class WorkflowTool(Tool): self.workflow_call_depth = workflow_call_depth self.label = label self._latest_usage = LLMUsage.empty_usage() + self.parent_trace_context: dict[str, str] | None = None super().__init__(entity=entity, runtime=runtime) @@ -90,11 +91,15 @@ class WorkflowTool(Tool): self._latest_usage = LLMUsage.empty_usage() + args: dict[str, Any] = {"inputs": tool_parameters, "files": files} + if self.parent_trace_context: + args["_parent_trace_context"] = self.parent_trace_context + result = generator.generate( app_model=app, workflow=workflow, user=user, - args={"inputs": tool_parameters, "files": files}, + args=args, invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index bb3b13e8c6..938a2f5e21 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): """ TOTAL_TOKENS = "total_tokens" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" TOTAL_PRICE = "total_price" CURRENCY = "currency" TOOL_INFO = "tool_info" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index beccf79344..92e9439acc 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]): outputs=outputs, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens, + WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 60d76db9b6..f498a23d13 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -61,6 +61,7 @@ class ToolNode(Node[ToolNodeData]): "provider_type": self.node_data.provider_type.value, "provider_id": self.node_data.provider_id, "plugin_unique_identifier": self.node_data.plugin_unique_identifier, + "credential_id": self.node_data.credential_id, } # get tool runtime @@ -105,6 +106,20 @@ class ToolNode(Node[ToolNodeData]): # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + from core.tools.workflow_as_tool.tool import WorkflowTool + + if isinstance(tool_runtime, WorkflowTool): + workflow_run_id_var = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID] + ) + tool_runtime.parent_trace_context = { + "trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "", + "parent_node_execution_id": self.execution_id, + "parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "", + "parent_app_id": self.app_id, + "parent_conversation_id": conversation_id.text if conversation_id else None, + } + try: message_stream = ToolEngine.generic_invoke( tool=tool_runtime, @@ -431,6 +446,8 @@ class ToolNode(Node[ToolNodeData]): } if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens + metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens + metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency diff --git a/api/enterprise/__init__.py b/api/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/__init__.py b/api/enterprise/telemetry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..ac4cdeb323 --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,83 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, field_validator + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload. + payload_fallback: Fallback payload (max 64KB). + metadata: Optional metadata dictionary. + """ + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + payload_fallback: bytes | None = None + metadata: dict[str, Any] | None = None + + @field_validator("payload_fallback") + @classmethod + def validate_payload_fallback_size(cls, v: bytes | None) -> bytes | None: + """Validate that payload_fallback does not exceed 64KB.""" + if v is not None and len(v) > 65536: # 64 * 1024 + raise ValueError("payload_fallback must not exceed 64KB") + return v + + class Config: + """Pydantic configuration.""" + + use_enum_values = False diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..ea8088695e --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from core.workflow.enums import WorkflowNodeExecutionMetadataKey +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..2b0c9ade7c --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,844 @@ +"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. + """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Slim span attrs: identity + structure + status + timing only -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + "dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + } + + trace_correlation_override: str | None = None + parent_span_id_source: str | None = None + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + trace_override_value = parent_ctx_dict.get("parent_workflow_run_id") + if isinstance(trace_override_value, str): + trace_correlation_override = trace_override_value + parent_span_value = parent_ctx_dict.get("parent_node_execution_id") + if isinstance(parent_span_value, str): + parent_span_id_source = parent_span_value + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name="dify.workflow.run", + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, labels) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + float(info.workflow_run_elapsed_time), + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Slim span attrs: identity + structure + status + timing -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + } + + trace_correlation_override = trace_correlation_override_param + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + override_value = parent_ctx_dict.get("parent_workflow_run_id") + if isinstance(override_value, str): + trace_correlation_override = override_value + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = self._labels( + **labels, + model_name=info.model_name or "", + ) + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + ), + ) + duration_labels = dict(labels) + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name="dify.message.run", + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider", ""), + model_name=metadata.get("ls_model_name", ""), + ) + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "gen_ai.tool.name": info.tool_name, + "dify.tool.time_cost": info.time_cost, + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name="dify.tool.execution", + attributes=attrs, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name="dify.moderation.check", + attributes=attrs, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": info.status, + "dify.suggested_question.error": info.error, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + "dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name="dify.suggested_question.generation", + attributes=attrs, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.dataset.error"] = info.error + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = [] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.names"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name="dify.dataset.retrieval", + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None, + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name="dify.generate_name.execution", + attributes=attrs, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.trace_id, + "dify.tenant_id": tenant_id, + "dify.user.id": user_id, + "dify.app.id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.operation.type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.latency": info.latency, + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name="dify.prompt_generation.execution", + attributes=attrs, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, labels) + if info.completion_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, labels) + + status = "failed" if info.error else "success" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..c0a5499c6c --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,33 @@ +from enum import StrEnum + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..38276c7f0f --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,130 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to ensure handlers fire. +""" + +from __future__ import annotations + +import logging +import uuid + +from events.app_event import app_was_created, app_was_deleted, app_was_updated +from events.feedback_event import feedback_was_created + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", + "_handle_feedback_created", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + exporter = get_enterprise_exporter() + if not exporter: + return + + tenant_id = str(getattr(sender, "tenant_id", "") or "") + payload = { + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + } + + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id=tenant_id, + event_id=str(uuid.uuid4()), + payload=payload, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + exporter = get_enterprise_exporter() + if not exporter: + return + + tenant_id = str(getattr(sender, "tenant_id", "") or "") + payload = { + "app_id": getattr(sender, "id", None), + } + + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id=tenant_id, + event_id=str(uuid.uuid4()), + payload=payload, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + exporter = get_enterprise_exporter() + if not exporter: + return + + tenant_id = str(getattr(sender, "tenant_id", "") or "") + payload = { + "app_id": getattr(sender, "id", None), + } + + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id=tenant_id, + event_id=str(uuid.uuid4()), + payload=payload, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + + +@feedback_was_created.connect +def _handle_feedback_created(sender: object, **kwargs: object) -> None: + from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + exporter = get_enterprise_exporter() + if not exporter: + return + + tenant_id = str(kwargs.get("tenant_id", "") or "") + payload = { + "message_id": getattr(sender, "message_id", None), + "app_id": getattr(sender, "app_id", None), + "conversation_id": getattr(sender, "conversation_id", None), + "from_end_user_id": getattr(sender, "from_end_user_id", None), + "from_account_id": getattr(sender, "from_account_id", None), + "rating": getattr(sender, "rating", None), + "from_source": getattr(sender, "from_source", None), + "content": getattr(sender, "content", None), + } + + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id=tenant_id, + event_id=str(uuid.uuid4()), + payload=payload, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..529c38741a --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ -0,0 +1,252 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + """Parse ``key=value,key2=value2`` into a dict.""" + if not raw: + return {} + headers: dict[str, str] = {} + for pair in raw.split(","): + if "=" not in pair: + continue + k, v = pair.split("=", 1) + headers[k.strip()] = v.strip() + return headers + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str]): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=True, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=True, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. + + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + factory = _ExporterFactory(protocol, endpoint, headers) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/gateway.py b/api/enterprise/telemetry/gateway.py new file mode 100644 index 0000000000..73886e327e --- /dev/null +++ b/api/enterprise/telemetry/gateway.py @@ -0,0 +1,199 @@ +"""Telemetry gateway routing and dispatch. + +Maps ``TelemetryCase`` → ``CaseRoute`` (signal type + CE eligibility) +and dispatches events to either the trace pipeline or the metric/log +Celery queue. + +Singleton lifecycle is managed by ``ext_enterprise_telemetry.init_app()`` +which creates the instance during single-threaded Flask app startup. +Access via ``ext_enterprise_telemetry.get_gateway()``. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from core.ops.entities.trace_entity import TraceTaskName +from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +CASE_TO_TRACE_TASK: dict[TelemetryCase, TraceTaskName] = { + TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE, + TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE, + TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE, + TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE, +} + +CASE_ROUTING: dict[TelemetryCase, CaseRoute] = { + TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), +} + + +def _is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def _should_drop_ee_only_event(route: CaseRoute) -> bool: + """Return True when the event is enterprise-only and EE telemetry is disabled.""" + return not route.ce_eligible and not _is_enterprise_telemetry_enabled() + + +class TelemetryGateway: + """Routes telemetry events to the trace pipeline or the metric/log Celery queue. + + Stateless — instantiated once during ``ext_enterprise_telemetry.init_app()`` + and shared for the lifetime of the process. + """ + + def emit( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, + ) -> None: + route = CASE_ROUTING.get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if _should_drop_ee_only_event(route): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + logger.debug( + "Gateway routing: case=%s, signal_type=%s, ce_eligible=%s", + case, + route.signal_type, + route.ce_eligible, + ) + + if route.signal_type is SignalType.TRACE: + self._emit_trace(case, context, payload, route, trace_manager) + else: + self._emit_metric_log(case, context, payload) + + def _emit_trace( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + route: CaseRoute, + trace_manager: TraceQueueManager | None, + ) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = CASE_TO_TRACE_TASK.get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + + queue_manager.add_trace_task(TraceTask(trace_task_name, **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + def _emit_metric_log( + self, + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + ) -> None: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + tenant_id = context.get("tenant_id", "") + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = self._handle_payload_sizing(payload, tenant_id, event_id) + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) + + def _handle_payload_sizing( + self, + payload: dict[str, Any], + tenant_id: str, + event_id: str, + ) -> tuple[dict[str, Any], str | None]: + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Module-level convenience wrapper. + + Fetches the gateway singleton from the extension; no-ops when + enterprise telemetry is disabled (gateway is ``None``). + """ + from extensions.ext_enterprise_telemetry import get_gateway + + gateway = get_gateway() + if gateway is not None: + gateway.emit(case, context, payload, trace_manager) diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..8f4760cac2 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,76 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar +from typing import cast + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = cast(int, uuid.UUID(source_id).int) & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return cast(int, uuid.UUID(correlation_id).int) + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..cfe1768a10 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,371 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. + + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. + """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from reference or fallback. + + Attempts to resolve payload_ref to full data. If that fails, + falls back to payload_fallback. If both fail, emits a degraded + event marker. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary. + """ + # For now, payload is directly in the envelope + # Future: implement payload_ref resolution from storage + payload = envelope.payload + + if not payload and envelope.payload_fallback: + import pickle + + try: + payload = pickle.loads(envelope.payload_fallback) # noqa: S301 + logger.debug("Used payload_fallback for event_id=%s", envelope.event_id) + except Exception: + logger.warning( + "Failed to deserialize payload_fallback for event_id=%s", + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Both ref and fallback failed - emit degraded event + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + # Emit degraded event marker + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name="dify.telemetry.rehydration_failed", + attributes={ + "dify.tenant_id": envelope.tenant_id, + "dify.event_id": envelope.event_id, + "dify.case": envelope.case, + "rehydration_failed": True, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.app.mode": payload.get("mode"), + } + + emit_metric_only_event( + event_name="dify.app.created", + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + { + "type": "app.created", + "tenant_id": envelope.tenant_id, + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + } + + emit_metric_only_event( + event_name="dify.app.updated", + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + { + "type": "app.updated", + "tenant_id": envelope.tenant_id, + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app.id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + } + + emit_metric_only_event( + event_name="dify.app.deleted", + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + { + "type": "app.deleted", + "tenant_id": envelope.tenant_id, + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + emit_metric_only_event( + event_name="dify.feedback.created", + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event (stub).""" + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event (stub).""" + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event (stub).""" + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event (stub).""" + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event (stub).""" + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event (stub).""" + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event (stub).""" + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..63d79e8dc4 --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,119 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. +""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import Any + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..3a0094b77c 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -3,6 +3,12 @@ from blinker import signal # sender: app app_was_created = signal("app-was-created") +# sender: app +app_was_deleted = signal("app-was-deleted") + +# sender: app +app_was_updated = signal("app-was-updated") + # sender: app, kwargs: app_model_config app_model_config_was_updated = signal("app-model-config-was-updated") diff --git a/api/events/feedback_event.py b/api/events/feedback_event.py new file mode 100644 index 0000000000..8d91d5c5e5 --- /dev/null +++ b/api/events/feedback_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: MessageFeedback, kwargs: tenant_id +feedback_was_created = signal("feedback-was-created") diff --git a/api/extensions/ext_enterprise_telemetry.py b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..a24e14efa7 --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,58 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter and TelemetryGateway singletons during +``create_app()`` (single-threaded), registers blinker event handlers, +and hooks atexit for graceful shutdown. + +Skipped entirely when ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED`` +are false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + from enterprise.telemetry.gateway import TelemetryGateway + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None +_gateway: TelemetryGateway | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter, _gateway + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + from enterprise.telemetry.gateway import TelemetryGateway + + _exporter = EnterpriseExporter(dify_config) + _gateway = TelemetryGateway() + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter + + +def get_gateway() -> TelemetryGateway | None: + return _gateway diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/services/app_service.py b/api/services/app_service.py index af458ff618..0422b4bab9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -340,6 +340,8 @@ class AppService: db.session.delete(app) db.session.commit() + app_was_deleted.send(app) + # clean up web app settings if FeatureService.get_system_features().webapp_auth.enabled: EnterpriseService.WebAppAuth.cleanup_webapp(app.id) diff --git a/api/services/message_service.py b/api/services/message_service.py index a53ca8b22d..26b220edfa 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -7,9 +7,10 @@ from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from events.feedback_event import feedback_was_created from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -179,6 +180,9 @@ class MessageService: db.session.commit() + if feedback and rating: + feedback_was_created.send(feedback, tenant_id=app_model.tenant_id) + return feedback @classmethod @@ -294,10 +298,15 @@ class MessageService: questions: list[str] = list(questions_sequence) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, + context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id), + payload={ + "message_id": message_id, + "suggested_question": questions, + "timer": timer, + }, ) ) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 50ea832085..c1c92b2de8 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,4 @@ +import logging from typing import Any from core.ops.entities.config_entity import BaseTracingConfig @@ -5,6 +6,8 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig +logger = logging.getLogger(__name__) + class OpsService: @classmethod @@ -135,12 +138,13 @@ class OpsService: return trace_config_data.to_dict() @classmethod - def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str): """ Create tracing app config :param app_id: app id :param tracing_provider: tracing provider :param tracing_config: tracing config + :param account_id: account id of the user creating the config :return: """ try: @@ -207,15 +211,19 @@ class OpsService: db.session.add(trace_config_data) db.session.commit() + # Log the creation with modifier information + logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id) + return {"result": "success"} @classmethod - def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str): """ Update tracing app config :param app_id: app id :param tracing_provider: tracing provider :param tracing_config: tracing config + :param account_id: account id of the user updating the config :return: """ try: @@ -251,14 +259,18 @@ class OpsService: current_trace_config.tracing_config = tracing_config db.session.commit() + # Log the update with modifier information + logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id) + return current_trace_config.to_dict() @classmethod - def delete_tracing_app_config(cls, app_id: str, tracing_provider: str): + def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str): """ Delete tracing app config :param app_id: app id :param tracing_provider: tracing provider + :param account_id: account id of the user deleting the config :return: """ trace_config = ( @@ -270,6 +282,9 @@ class OpsService: if not trace_config: return None + # Log the deletion with modifier information + logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id) + db.session.delete(trace_config) db.session.commit() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6404136994..0d6e2eb4b7 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -27,6 +27,7 @@ from core.workflow.nodes.start.entities import StartNodeData from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -647,6 +648,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) node_data = node_config.get("data", {}) + workflow_execution_id: str | None = None if node_type.is_start_node: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) @@ -672,10 +674,13 @@ class WorkflowService: node_type=node_type, conversation_id=conversation_id, ) + workflow_execution_id = variable_pool.system_variables.workflow_execution_id else: + workflow_execution_id = str(uuid.uuid4()) + system_variable = SystemVariable(workflow_execution_id=workflow_execution_id) variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=system_variable, user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -729,6 +734,13 @@ class WorkflowService: with Session(db.engine) as session: outputs = workflow_node_execution.load_full_outputs(session, storage) + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + user_id=account.id, + ) + with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( session=session, @@ -784,19 +796,20 @@ class WorkflowService: Returns: WorkflowNodeExecution: The execution result """ + created_at = naive_utc_now() node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn) + finished_at = naive_utc_now() - # Create base node execution node_execution = WorkflowNodeExecution( - id=str(uuid.uuid4()), + id=node.execution_id or str(uuid.uuid4()), workflow_id="", # Single-step execution has no workflow ID index=1, node_id=node_id, node_type=node.node_type, title=node.title, elapsed_time=time.perf_counter() - start_at, - created_at=naive_utc_now(), - finished_at=naive_utc_now(), + created_at=created_at, + finished_at=finished_at, ) # Populate execution result data diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. +""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..5b61e9e7a1 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,12 +39,24 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.warning("Enterprise trace failed for app_id: %s", app_id, exc_info=True) + if trace_instance: with current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: logger.info("error:\n\n\n%s\n\n\n\n", e) @@ -52,4 +64,12 @@ def process_trace_tasks(file_info): redis_client.incr(failed_key) logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..25adda21ec --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,200 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +This test suite verifies that TraceQueueManager correctly drops trace tasks +when telemetry is disabled, proving Bug 1 from code review is a false positive. + +The guard logic moved from persistence.py to TraceQueueManager.add_trace_task() +at line 1282 of ops_trace_manager.py: + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + trace_manager_queue.put(trace_task) + +Tasks are only enqueued if EITHER: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When BOTH are false, tasks are silently dropped (correct behavior). +""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. + + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..ae7b2ce818 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_enterprise_trace_creates_trace_task(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, _mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..076cd00879 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry import is_enterprise_telemetry_enabled +from enterprise.telemetry.contracts import TelemetryCase +from enterprise.telemetry.gateway import TelemetryGateway + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + def test_metric_case_routes_to_celery_task( + self, + gateway: TelemetryGateway, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + def test_tool_execution_metric_routed( + self, + gateway: TelemetryGateway, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + gateway.emit(TelemetryCase.TOOL_EXECUTION, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.TOOL_EXECUTION + + def test_moderation_check_metric_routed( + self, + gateway: TelemetryGateway, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + gateway.emit(TelemetryCase.MODERATION_CHECK, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.MODERATION_CHECK + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": "conv-123"} + + gateway.emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + gateway.emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + ) -> None: + with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + gateway.emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_contracts.py b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py new file mode 100644 index 0000000000..ce2162c5f4 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_contracts.py @@ -0,0 +1,264 @@ +"""Unit tests for telemetry gateway contracts.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.gateway import CASE_ROUTING + + +class TestTelemetryCase: + """Tests for TelemetryCase enum.""" + + def test_all_cases_defined(self) -> None: + """Verify all 14 telemetry cases are defined.""" + expected_cases = { + "WORKFLOW_RUN", + "NODE_EXECUTION", + "DRAFT_NODE_EXECUTION", + "MESSAGE_RUN", + "TOOL_EXECUTION", + "MODERATION_CHECK", + "SUGGESTED_QUESTION", + "DATASET_RETRIEVAL", + "GENERATE_NAME", + "PROMPT_GENERATION", + "APP_CREATED", + "APP_UPDATED", + "APP_DELETED", + "FEEDBACK_CREATED", + } + actual_cases = {case.name for case in TelemetryCase} + assert actual_cases == expected_cases + + def test_case_values(self) -> None: + """Verify case enum values are correct.""" + assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run" + assert TelemetryCase.NODE_EXECUTION.value == "node_execution" + assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution" + assert TelemetryCase.MESSAGE_RUN.value == "message_run" + assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution" + assert TelemetryCase.MODERATION_CHECK.value == "moderation_check" + assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question" + assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval" + assert TelemetryCase.GENERATE_NAME.value == "generate_name" + assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation" + assert TelemetryCase.APP_CREATED.value == "app_created" + assert TelemetryCase.APP_UPDATED.value == "app_updated" + assert TelemetryCase.APP_DELETED.value == "app_deleted" + assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created" + + +class TestCaseRoute: + """Tests for CaseRoute model.""" + + def test_valid_trace_route(self) -> None: + """Verify valid trace route creation.""" + route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True) + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_valid_metric_log_route(self) -> None: + """Verify valid metric_log route creation.""" + route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False) + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_invalid_signal_type(self) -> None: + """Verify invalid signal_type is rejected.""" + with pytest.raises(ValidationError): + CaseRoute(signal_type="invalid", ce_eligible=True) + + +class TestTelemetryEnvelope: + """Tests for TelemetryEnvelope model.""" + + def test_valid_envelope_minimal(self) -> None: + """Verify valid minimal envelope creation.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + assert envelope.case == TelemetryCase.WORKFLOW_RUN + assert envelope.tenant_id == "tenant-123" + assert envelope.event_id == "event-456" + assert envelope.payload == {"key": "value"} + assert envelope.payload_fallback is None + assert envelope.metadata is None + + def test_valid_envelope_full(self) -> None: + """Verify valid envelope with all fields.""" + metadata = {"source": "api"} + fallback = b"fallback data" + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="tenant-789", + event_id="event-012", + payload={"message": "hello"}, + payload_fallback=fallback, + metadata=metadata, + ) + assert envelope.case == TelemetryCase.MESSAGE_RUN + assert envelope.tenant_id == "tenant-789" + assert envelope.event_id == "event-012" + assert envelope.payload == {"message": "hello"} + assert envelope.payload_fallback == fallback + assert envelope.metadata == metadata + + def test_missing_required_case(self) -> None: + """Verify missing case field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_tenant_id(self) -> None: + """Verify missing tenant_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + event_id="event-456", + payload={"key": "value"}, + ) + + def test_missing_required_event_id(self) -> None: + """Verify missing event_id field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + payload={"key": "value"}, + ) + + def test_missing_required_payload(self) -> None: + """Verify missing payload field is rejected.""" + with pytest.raises(ValidationError): + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + ) + + def test_payload_fallback_within_limit(self) -> None: + """Verify payload_fallback within 64KB limit is accepted.""" + fallback = b"x" * 65536 + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + payload_fallback=fallback, + ) + assert envelope.payload_fallback == fallback + + def test_payload_fallback_exceeds_limit(self) -> None: + """Verify payload_fallback exceeding 64KB is rejected.""" + fallback = b"x" * 65537 + with pytest.raises(ValidationError) as exc_info: + TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + payload_fallback=fallback, + ) + assert "64KB" in str(exc_info.value) + + def test_payload_fallback_none(self) -> None: + """Verify payload_fallback can be None.""" + envelope = TelemetryEnvelope( + case=TelemetryCase.WORKFLOW_RUN, + tenant_id="tenant-123", + event_id="event-456", + payload={"key": "value"}, + payload_fallback=None, + ) + assert envelope.payload_fallback is None + + +class TestCaseRouting: + """Tests for CASE_ROUTING table.""" + + def test_all_cases_routed(self) -> None: + """Verify all 14 cases have routing entries.""" + assert len(CASE_ROUTING) == 14 + for case in TelemetryCase: + assert case in CASE_ROUTING + + def test_trace_ce_eligible_cases(self) -> None: + """Verify trace cases with CE eligibility.""" + ce_eligible_trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + } + for case in ce_eligible_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is True + + def test_trace_enterprise_only_cases(self) -> None: + """Verify trace cases that are enterprise-only.""" + enterprise_only_trace_cases = { + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + } + for case in enterprise_only_trace_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.TRACE + assert route.ce_eligible is False + + def test_metric_log_cases(self) -> None: + """Verify metric/log-only cases.""" + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + } + for case in metric_log_cases: + route = CASE_ROUTING[case] + assert route.signal_type == SignalType.METRIC_LOG + assert route.ce_eligible is False + + def test_routing_table_completeness(self) -> None: + """Verify routing table covers all cases with correct types.""" + trace_cases = { + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + } + metric_log_cases = { + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + } + + all_cases = trace_cases | metric_log_cases + assert len(all_cases) == 14 + assert all_cases == set(TelemetryCase) + + for case in trace_cases: + assert CASE_ROUTING[case].signal_type == SignalType.TRACE + + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG diff --git a/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py new file mode 100644 index 0000000000..13902e8340 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_event_handlers.py @@ -0,0 +1,134 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry import event_handlers +from enterprise.telemetry.contracts import TelemetryCase + + +@pytest.fixture +def mock_exporter(): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock: + exporter = MagicMock() + mock.return_value = exporter + yield exporter + + +@pytest.fixture +def mock_task(): + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry") as mock: + yield mock + + +def test_handle_app_created_calls_task(mock_exporter, mock_task): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[0][0] + assert "app_created" in call_args + assert "tenant-456" in call_args + assert "app-123" in call_args + assert "chat" in call_args + + +def test_handle_app_created_no_exporter(mock_task): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_created(sender) + + mock_task.delay.assert_not_called() + + +def test_handle_app_updated_calls_task(mock_exporter, mock_task): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_updated(sender) + + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[0][0] + assert "app_updated" in call_args + assert "tenant-456" in call_args + assert "app-123" in call_args + + +def test_handle_app_deleted_calls_task(mock_exporter, mock_task): + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + + event_handlers._handle_app_deleted(sender) + + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[0][0] + assert "app_deleted" in call_args + assert "tenant-456" in call_args + assert "app-123" in call_args + + +def test_handle_feedback_created_calls_task(mock_exporter, mock_task): + sender = MagicMock() + sender.message_id = "msg-123" + sender.app_id = "app-456" + sender.conversation_id = "conv-789" + sender.from_end_user_id = "user-001" + sender.from_account_id = None + sender.rating = "like" + sender.from_source = "api" + sender.content = "Great response!" + + event_handlers._handle_feedback_created(sender, tenant_id="tenant-456") + + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[0][0] + assert "feedback_created" in call_args + assert "tenant-456" in call_args + assert "msg-123" in call_args + assert "app-456" in call_args + assert "conv-789" in call_args + assert "user-001" in call_args + assert "like" in call_args + assert "api" in call_args + assert "Great response!" in call_args + + +def test_handle_feedback_created_no_exporter(mock_task): + with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None): + sender = MagicMock() + sender.message_id = "msg-123" + + event_handlers._handle_feedback_created(sender, tenant_id="tenant-456") + + mock_task.delay.assert_not_called() + + +def test_handlers_create_valid_envelopes(mock_exporter, mock_task): + import json + + from enterprise.telemetry.contracts import TelemetryEnvelope + + sender = MagicMock() + sender.id = "app-123" + sender.tenant_id = "tenant-456" + sender.mode = "chat" + + event_handlers._handle_app_created(sender) + + call_args = mock_task.delay.call_args[0][0] + envelope_dict = json.loads(call_args) + envelope = TelemetryEnvelope(**envelope_dict) + + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-456" + assert envelope.event_id + assert envelope.payload["app_id"] == "app-123" + assert envelope.payload["mode"] == "chat" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_gateway.py b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py new file mode 100644 index 0000000000..ff226dd56c --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_gateway.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.gateway import ( + CASE_ROUTING, + CASE_TO_TRACE_TASK, + PAYLOAD_SIZE_THRESHOLD_BYTES, + TelemetryGateway, + emit, +) + + +class TestCaseRoutingTable: + def test_all_cases_have_routing(self) -> None: + for case in TelemetryCase: + assert case in CASE_ROUTING, f"Missing routing for {case}" + + def test_trace_cases(self) -> None: + trace_cases = [ + TelemetryCase.WORKFLOW_RUN, + TelemetryCase.MESSAGE_RUN, + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in trace_cases: + assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace" + + def test_metric_log_cases(self) -> None: + metric_log_cases = [ + TelemetryCase.APP_CREATED, + TelemetryCase.APP_UPDATED, + TelemetryCase.APP_DELETED, + TelemetryCase.FEEDBACK_CREATED, + TelemetryCase.TOOL_EXECUTION, + TelemetryCase.MODERATION_CHECK, + TelemetryCase.SUGGESTED_QUESTION, + TelemetryCase.DATASET_RETRIEVAL, + TelemetryCase.GENERATE_NAME, + ] + for case in metric_log_cases: + assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log" + + def test_ce_eligible_cases(self) -> None: + ce_eligible_cases = [TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN] + for case in ce_eligible_cases: + assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible" + + def test_enterprise_only_cases(self) -> None: + enterprise_only_cases = [ + TelemetryCase.NODE_EXECUTION, + TelemetryCase.DRAFT_NODE_EXECUTION, + TelemetryCase.PROMPT_GENERATION, + ] + for case in enterprise_only_cases: + assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only" + + def test_trace_cases_have_task_name_mapping(self) -> None: + trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE] + for case in trace_cases: + assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}" + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestTelemetryGatewayTraceRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_trace_case_routes_to_trace_manager( + self, + _mock_ee_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_ce_eligible_trace_enqueued_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False) + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + _mock_ee_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_enterprise_only_trace_enqueued_when_ee_enabled( + self, + _mock_ee_enabled: MagicMock, + gateway: TelemetryGateway, + mock_trace_manager: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestTelemetryGatewayMetricLogRouting: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_metric_case_routes_to_celery_task( + self, + mock_delay: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_envelope_has_unique_event_id( + self, + mock_delay: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + assert mock_delay.call_count == 2 + envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0]) + envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0]) + assert envelope1.event_id != envelope2.event_id + + +class TestTelemetryGatewayPayloadSizing: + @pytest.fixture + def gateway(self) -> TelemetryGateway: + return TelemetryGateway() + + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_small_payload_inlined( + self, + mock_delay: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + payload = {"key": "small_value"} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + @patch("enterprise.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_stored( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + gateway: TelemetryGateway, + ) -> None: + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + mock_storage.save.assert_called_once() + storage_key = mock_storage.save.call_args[0][0] + assert storage_key.startswith("telemetry/tenant-123/") + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == {} + assert envelope.metadata is not None + assert envelope.metadata["payload_ref"] == storage_key + + @patch("enterprise.telemetry.gateway.storage") + @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") + def test_large_payload_fallback_on_storage_error( + self, + mock_delay: MagicMock, + mock_storage: MagicMock, + gateway: TelemetryGateway, + ) -> None: + mock_storage.save.side_effect = Exception("Storage failure") + context = {"tenant_id": "tenant-123"} + large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000) + payload = {"key": large_value} + + gateway.emit(TelemetryCase.APP_CREATED, context, payload) + + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.payload == payload + assert envelope.metadata is None + + +class TestModuleLevelFunctions: + @patch("extensions.ext_enterprise_telemetry.get_gateway") + @patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True) + def test_emit_function_uses_gateway( + self, + _mock_ee_enabled: MagicMock, + mock_get_gateway: MagicMock, + mock_ops_trace_manager: tuple[MagicMock, MagicMock], + ) -> None: + mock_gateway = TelemetryGateway() + mock_get_gateway.return_value = mock_gateway + mock_trace_manager = MagicMock() + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + with patch.object(mock_gateway, "emit") as mock_emit: + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + mock_emit.assert_called_once_with(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + +class TestTraceTaskNameMapping: + def test_workflow_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE + + def test_message_run_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE + + def test_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE + + def test_draft_node_execution_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_prompt_generation_mapping(self) -> None: + assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE diff --git a/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py new file mode 100644 index 0000000000..581e1631d5 --- /dev/null +++ b/api/tests/unit_tests/enterprise/telemetry/test_metric_handler.py @@ -0,0 +1,452 @@ +"""Unit tests for EnterpriseMetricHandler.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + + +@pytest.fixture +def mock_redis(): + with patch("enterprise.telemetry.metric_handler.redis_client") as mock: + yield mock + + +@pytest.fixture +def sample_envelope(): + return TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123", "name": "Test App"}, + ) + + +def test_dispatch_app_created(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_called_once_with(sample_envelope) + + +def test_dispatch_app_updated(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="test-tenant", + event_id="test-event-456", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_updated") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_app_deleted(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="test-tenant", + event_id="test-event-789", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_deleted") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_feedback_created(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="test-tenant", + event_id="test-event-abc", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_feedback_created") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_message_run(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MESSAGE_RUN, + tenant_id="test-tenant", + event_id="test-event-msg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_message_run") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_tool_execution(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.TOOL_EXECUTION, + tenant_id="test-tenant", + event_id="test-event-tool", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_tool_execution") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_moderation_check(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.MODERATION_CHECK, + tenant_id="test-tenant", + event_id="test-event-mod", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_moderation_check") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_suggested_question(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.SUGGESTED_QUESTION, + tenant_id="test-tenant", + event_id="test-event-sq", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_suggested_question") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_dataset_retrieval(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.DATASET_RETRIEVAL, + tenant_id="test-tenant", + event_id="test-event-ds", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_dataset_retrieval") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_generate_name(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.GENERATE_NAME, + tenant_id="test-tenant", + event_id="test-event-gn", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_generate_name") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_dispatch_prompt_generation(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.PROMPT_GENERATION, + tenant_id="test-tenant", + event_id="test-event-pg", + payload={}, + ) + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_prompt_generation") as mock_handler: + handler.handle(envelope) + mock_handler.assert_called_once_with(envelope) + + +def test_all_known_cases_have_handlers(mock_redis): + mock_redis.set.return_value = True + handler = EnterpriseMetricHandler() + + for case in TelemetryCase: + envelope = TelemetryEnvelope( + case=case, + tenant_id="test-tenant", + event_id=f"test-{case.value}", + payload={}, + ) + handler.handle(envelope) + + +def test_idempotency_duplicate(sample_envelope, mock_redis): + mock_redis.set.return_value = None + + handler = EnterpriseMetricHandler() + with patch.object(handler, "_on_app_created") as mock_handler: + handler.handle(sample_envelope) + mock_handler.assert_not_called() + + +def test_idempotency_first_seen(sample_envelope, mock_redis): + mock_redis.set.return_value = True + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + mock_redis.set.assert_called_once_with( + "telemetry:dedup:test-tenant:test-event-123", + b"1", + nx=True, + ex=3600, + ) + + +def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog): + mock_redis.set.side_effect = Exception("Redis unavailable") + + handler = EnterpriseMetricHandler() + is_dup = handler._is_duplicate(sample_envelope) + + assert is_dup is False + assert "Redis unavailable for deduplication check" in caplog.text + + +def test_rehydration_uses_payload(sample_envelope): + handler = EnterpriseMetricHandler() + payload = handler._rehydrate(sample_envelope) + + assert payload == {"app_id": "app-123", "name": "Test App"} + + +def test_rehydration_fallback(): + import pickle + + fallback_data = {"fallback": "data"} + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fb", + payload={}, + payload_fallback=pickle.dumps(fallback_data), + ) + + handler = EnterpriseMetricHandler() + payload = handler._rehydrate(envelope) + + assert payload == fallback_data + + +def test_rehydration_emits_degraded_event_on_failure(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-fail", + payload={}, + payload_fallback=None, + ) + + handler = EnterpriseMetricHandler() + with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit: + payload = handler._rehydrate(envelope) + + assert payload == {} + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == "dify.telemetry.rehydration_failed" + assert call_args[1]["attributes"]["rehydration_failed"] is True + + +def test_on_app_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789", "mode": "chat"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_created(envelope) + + mock_emit.assert_called_once_with( + event_name="dify.app.created", + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + "dify.app.mode": "chat", + }, + tenant_id="tenant-123", + ) + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][1] == 1 + assert call_args[0][2]["type"] == "app.created" + assert call_args[0][2]["tenant_id"] == "tenant-123" + + +def test_on_app_updated_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_UPDATED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_updated(envelope) + + mock_emit.assert_called_once_with( + event_name="dify.app.updated", + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + }, + tenant_id="tenant-123", + ) + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][2]["type"] == "app.updated" + + +def test_on_app_deleted_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_DELETED, + tenant_id="tenant-123", + event_id="event-456", + payload={"app_id": "app-789"}, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_get_exporter.return_value = mock_exporter + + handler._on_app_deleted(envelope) + + mock_emit.assert_called_once_with( + event_name="dify.app.deleted", + attributes={ + "dify.app.id": "app-789", + "dify.tenant_id": "tenant-123", + }, + tenant_id="tenant-123", + ) + mock_exporter.increment_counter.assert_called_once() + call_args = mock_exporter.increment_counter.call_args + assert call_args[0][2]["type"] == "app.deleted" + + +def test_on_feedback_created_emits_correct_event(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = True + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[1]["event_name"] == "dify.feedback.created" + assert call_args[1]["attributes"]["dify.message.id"] == "msg-001" + assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!" + assert call_args[1]["tenant_id"] == "tenant-123" + assert call_args[1]["user_id"] == "user-456" + + mock_exporter.increment_counter.assert_called_once() + counter_args = mock_exporter.increment_counter.call_args + assert counter_args[0][2]["app_id"] == "app-789" + assert counter_args[0][2]["rating"] == "like" + + +def test_on_feedback_created_without_content(mock_redis): + mock_redis.set.return_value = True + envelope = TelemetryEnvelope( + case=TelemetryCase.FEEDBACK_CREATED, + tenant_id="tenant-123", + event_id="event-456", + payload={ + "message_id": "msg-001", + "app_id": "app-789", + "conversation_id": "conv-123", + "from_end_user_id": "user-456", + "from_account_id": None, + "rating": "like", + "from_source": "api", + "content": "Great!", + }, + ) + + handler = EnterpriseMetricHandler() + with ( + patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter, + patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit, + ): + mock_exporter = MagicMock() + mock_exporter.include_content = False + mock_get_exporter.return_value = mock_exporter + + handler._on_feedback_created(envelope) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert "dify.feedback.content" not in call_args[1]["attributes"] diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text