diff --git a/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md b/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md
new file mode 100644
index 0000000000..f03c41cc25
--- /dev/null
+++ b/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md
@@ -0,0 +1,27 @@
+# Notes: `large_language_model.py`
+
+## Purpose
+
+Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to
+bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`).
+
+## Key behaviors / invariants
+
+- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from
+  the first yielded `LLMResultChunk`.
+- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by
+  `_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`.
+- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking
+  patterns (IDs anchored to first chunk, every chunk, or missing entirely).
+- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to
+  surface invalid delta sequences explicitly.
+- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging.
+- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must
+  be re-attached in this layer before callbacks/consumers use them.
+- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation
+  unless `callback.raise_error` is true.
+
+## Test focus
+
+- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and
+  patches `_gen_tool_call_id` for deterministic IDs.
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index c0f4c504d9..7a0757f219 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -1,7 +1,7 @@
 import logging
 import time
 import uuid
-from collections.abc import Generator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from typing import Union
 
 from pydantic import ConfigDict
@@ -30,6 +30,142 @@ def _gen_tool_call_id() -> str:
     return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
+def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
+    if not callbacks:
+        return
+
+    for callback in callbacks:
+        try:
+            invoke(callback)
+        except Exception as e:
+            if callback.raise_error:
+                raise
+            logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
+
+
+def _get_or_create_tool_call(
+    existing_tools_calls: list[AssistantPromptMessage.ToolCall],
+    tool_call_id: str,
+) -> AssistantPromptMessage.ToolCall:
+    """
+    Get or create a tool call by ID.
+
+    If `tool_call_id` is empty, returns the most recently created tool call.
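+
+    A rough usage sketch (illustrative only, using the names defined in this module):
+
+        calls: list[AssistantPromptMessage.ToolCall] = []
+        first = _get_or_create_tool_call(calls, "call-1")  # no match, so a new tool call is created and appended
+        assert _get_or_create_tool_call(calls, "call-1") is first  # matched by id, the same object is returned
+        assert _get_or_create_tool_call(calls, "") is first  # empty id falls back to the most recent tool call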
+ """ + if not tool_call_id: + if not existing_tools_calls: + raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") + return existing_tools_calls[-1] + + tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + existing_tools_calls.append(tool_call) + + return tool_call + + +def _merge_tool_call_delta( + tool_call: AssistantPromptMessage.ToolCall, + delta: AssistantPromptMessage.ToolCall, +) -> None: + if delta.id: + tool_call.id = delta.id + if delta.type: + tool_call.type = delta.type + if delta.function.name: + tool_call.function.name = delta.function.name + if delta.function.arguments: + tool_call.function.arguments += delta.function.arguments + + +def _build_llm_result_from_first_chunk( + model: str, + prompt_messages: Sequence[PromptMessage], + chunks: Iterator[LLMResultChunk], +) -> LLMResult: + """ + Build a single `LLMResult` from the first returned chunk. + + This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + """ + content = "" + content_list: list[PromptMessageContentUnionTypes] = [] + usage = LLMUsage.empty_usage() + system_fingerprint: str | None = None + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) + + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=content or content_list, + tool_calls=tools_calls, + ), + usage=usage, + system_fingerprint=system_fingerprint, + ) + + +def _invoke_llm_via_plugin( + *, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + model_parameters: dict, + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, +) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_llm( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + +def _normalize_non_stream_plugin_result( + model: str, + prompt_messages: Sequence[PromptMessage], + result: Union[LLMResult, Iterator[LLMResultChunk]], +) -> LLMResult: + if isinstance(result, LLMResult): + return result + return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result) + + def _increase_tool_call( new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] ): @@ -40,42 +176,13 @@ def 
_increase_tool_call( :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. """ - def get_tool_call(tool_call_id: str): - """ - Get or create a tool call by ID - - :param tool_call_id: tool call ID - :return: existing or new tool call - """ - if not tool_call_id: - return existing_tools_calls[-1] - - _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) - if _tool_call is None: - _tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(_tool_call) - - return _tool_call - for new_tool_call in new_tool_calls: # generate ID for tool calls with function name but no ID to track them if new_tool_call.function.name and not new_tool_call.id: new_tool_call.id = _gen_tool_call_id() - # get tool call - tool_call = get_tool_call(new_tool_call.id) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments + + tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) + _merge_tool_call_delta(tool_call, new_tool_call) class LargeLanguageModel(AIModel): @@ -141,10 +248,7 @@ class LargeLanguageModel(AIModel): result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - result = plugin_model_manager.invoke_llm( + result = _invoke_llm_via_plugin( tenant_id=self.tenant_id, user_id=user or "unknown", plugin_id=self.plugin_id, @@ -154,38 +258,13 @@ class LargeLanguageModel(AIModel): model_parameters=model_parameters, prompt_messages=prompt_messages, tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) if not stream: - content = "" - content_list = [] - usage = LLMUsage.empty_usage() - system_fingerprint = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - for chunk in result: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - usage = chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = chunk.system_fingerprint - break - - result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, + result = _normalize_non_stream_plugin_result( + model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: self._trigger_invoke_error_callbacks( @@ -425,27 +504,21 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - 
logger.warning( - "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_before_invoke", + invoke=lambda callback: callback.on_before_invoke( + llm_instance=self, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_new_chunk_callbacks( self, @@ -473,26 +546,22 @@ class LargeLanguageModel(AIModel): :param stream: is stream response :param user: unique user id """ - if callbacks: - for callback in callbacks: - try: - callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e) + _run_callbacks( + callbacks, + event="on_new_chunk", + invoke=lambda callback: callback.on_new_chunk( + llm_instance=self, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_after_invoke_callbacks( self, @@ -521,28 +590,22 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_after_invoke", + invoke=lambda callback: callback.on_after_invoke( + llm_instance=self, + result=result, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_invoke_error_callbacks( self, @@ -571,25 +634,19 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_invoke_error", + invoke=lambda callback: callback.on_invoke_error( + llm_instance=self, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py index 93d8a20cac..5fbdabceed 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py +++ 
b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +import pytest + from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call @@ -97,3 +99,14 @@ def test__increase_tool_call(): mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) + + +def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): + inputs = [ + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), + ] + actual: list[ToolCall] = [] + with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): + with pytest.raises(ValueError): + _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py new file mode 100644 index 0000000000..91352b2a5f --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -0,0 +1,103 @@ +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result + + +def _make_chunk( + *, + model: str = "test-model", + content: str | list[TextPromptMessageContent] | None, + tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, + usage: LLMUsage | None = None, + system_fingerprint: str | None = None, +) -> LLMResultChunk: + message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) + delta = LLMResultChunkDelta(index=0, message=message, usage=usage) + return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) + + +def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls(): + prompt_messages = [UserPromptMessage(content="hi")] + + tool_calls = [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), + ), + ] + + usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) + chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == "hello" + assert result.usage.prompt_tokens 
== 1 + assert result.system_fingerprint == "fp-1" + assert result.message.tool_calls == [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), + ) + ] + + +def test__normalize_non_stream_plugin_result__from_first_chunk_list_content(): + prompt_messages = [UserPromptMessage(content="hi")] + + content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] + chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage()) + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.message.content == content_list + + +def test__normalize_non_stream_plugin_result__passthrough_llm_result(): + prompt_messages = [UserPromptMessage(content="hi")] + llm_result = LLMResult( + model="test-model", + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + ) + + assert ( + _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) + == llm_result + ) + + +def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): + prompt_messages = [UserPromptMessage(content="hi")] + + result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == [] + assert result.message.tool_calls == [] + assert result.usage == LLMUsage.empty_usage() + assert result.system_fingerprint is None