From 174c712a4641df7eccda8778222fd29d1de8ccc6 Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Thu, 13 Feb 2025 15:00:16 -0600 Subject: [PATCH] LLm Factory + Provider base --- graphrag/config/enums.py | 3 + graphrag/llm/__init__.py | 2 + graphrag/llm/clients/__init__.py | 0 graphrag/llm/clients/azure_openai.py | 0 graphrag/llm/clients/base_llm.py | 17 +++++ graphrag/llm/clients/mock.py | 0 graphrag/llm/clients/openai.py | 0 graphrag/llm/clients/static.py | 0 graphrag/llm/factory.py | 47 ++++++++++++++ graphrag/llm/manager.py | 97 ++++++++++++++++++++++++++++ graphrag/llm/protocols/__init__.py | 7 ++ graphrag/llm/protocols/chat.py | 19 ++++++ graphrag/llm/protocols/embedding.py | 19 ++++++ 13 files changed, 211 insertions(+) create mode 100644 graphrag/llm/__init__.py create mode 100644 graphrag/llm/clients/__init__.py create mode 100644 graphrag/llm/clients/azure_openai.py create mode 100644 graphrag/llm/clients/base_llm.py create mode 100644 graphrag/llm/clients/mock.py create mode 100644 graphrag/llm/clients/openai.py create mode 100644 graphrag/llm/clients/static.py create mode 100644 graphrag/llm/factory.py create mode 100644 graphrag/llm/manager.py create mode 100644 graphrag/llm/protocols/__init__.py create mode 100644 graphrag/llm/protocols/chat.py create mode 100644 graphrag/llm/protocols/embedding.py diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 450ec6bd..1c2cde0d 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -112,6 +112,9 @@ class LLMType(str, Enum): # Debug StaticResponse = "static_response" + # Mock + Mock = "mock" + def __repr__(self): """Get a string representation.""" return f'"{self.value}"' diff --git a/graphrag/llm/__init__.py b/graphrag/llm/__init__.py new file mode 100644 index 00000000..9f51bd8a --- /dev/null +++ b/graphrag/llm/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Microsoft Corporation. 
+# Licensed under the MIT License
\ No newline at end of file
diff --git a/graphrag/llm/clients/__init__.py b/graphrag/llm/clients/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphrag/llm/clients/azure_openai.py b/graphrag/llm/clients/azure_openai.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphrag/llm/clients/base_llm.py b/graphrag/llm/clients/base_llm.py
new file mode 100644
index 00000000..1cdc6a20
--- /dev/null
+++ b/graphrag/llm/clients/base_llm.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the base LLM class."""
+
+from typing import Protocol
+
+class BaseLLM(Protocol):
+    """A base class for LLMs."""
+
+    def get_response(self, input: str) -> str:
+        """Get a response from the LLM."""
+        ...
+
+    def get_embedding(self, input: str) -> list[float]:
+        """Get an embedding from the LLM."""
+        ...
\ No newline at end of file
diff --git a/graphrag/llm/clients/mock.py b/graphrag/llm/clients/mock.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphrag/llm/clients/openai.py b/graphrag/llm/clients/openai.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphrag/llm/clients/static.py b/graphrag/llm/clients/static.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphrag/llm/factory.py b/graphrag/llm/factory.py
new file mode 100644
index 00000000..72f0ca6d
--- /dev/null
+++ b/graphrag/llm/factory.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A package containing a factory for supported llm types."""
+
+from typing import Any, Callable
+from graphrag.config.enums import LLMType
+from graphrag.llm.protocols.chat import ChatLLM
+from graphrag.llm.protocols.embedding import EmbeddingLLM
+
+
+class LLMFactory:
+    """A factory for creating LLM instances."""
+
+    _chat_registry: dict[str, Callable[..., ChatLLM]] = {}
+    _embedding_registry: dict[str, Callable[..., EmbeddingLLM]] = {}
+
+    @classmethod
+    def register_chat(cls, key: str, creator: Callable[..., ChatLLM]) -> None:
+        cls._chat_registry[key] = creator
+
+    @classmethod
+    def register_embedding(cls, key: str, creator: Callable[..., EmbeddingLLM]) -> None:
+        cls._embedding_registry[key] = creator
+
+    @classmethod
+    def create_chat_llm(cls, key: str, **kwargs: Any) -> ChatLLM:
+        if key not in cls._chat_registry:
+            msg = f"ChatLLM implementation '{key}' is not registered."
+            raise ValueError(msg)
+        return cls._chat_registry[key](**kwargs)
+
+    @classmethod
+    def create_embedding_llm(cls, key: str, **kwargs: Any) -> EmbeddingLLM:
+        if key not in cls._embedding_registry:
+            msg = f"EmbeddingLLM implementation '{key}' is not registered."
+            raise ValueError(msg)
+        return cls._embedding_registry[key](**kwargs)
+
+# --- Register default implementations ---
+LLMFactory.register_chat(LLMType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChat(**kwargs))
+LLMFactory.register_chat(LLMType.OpenAIChat, lambda **kwargs: OpenAIChat(**kwargs))
+LLMFactory.register_chat(LLMType.Mock, lambda **kwargs: MockChatLLM())
+
+LLMFactory.register_embedding(LLMType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbedding(**kwargs))
+LLMFactory.register_embedding(LLMType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbedding(**kwargs))
+LLMFactory.register_embedding(LLMType.Mock, lambda **kwargs: MockEmbeddingLLM())
diff --git a/graphrag/llm/manager.py b/graphrag/llm/manager.py
new file mode 100644
index 00000000..019986f9
--- /dev/null
+++ b/graphrag/llm/manager.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Singleton LLM Manager for ChatLLM and EmbeddingsLLM instances.
+
+This manager lets you register chat and embeddings LLMs independently.
+It leverages the LLMFactory for instantiation.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from graphrag.llm.protocols import ChatLLM, EmbeddingLLM
+from graphrag.llm.factory import LLMFactory
+
+
+class LLMManager:
+    """Singleton manager for LLM instances."""
+
+    _instance: LLMManager | None = None
+
+    def __new__(cls) -> LLMManager:
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        # Avoid reinitialization in the singleton.
+        if not hasattr(self, "_initialized"):
+            self.chat_llms: dict[str, ChatLLM] = {}
+            self.embedding_llms: dict[str, EmbeddingLLM] = {}
+            self._initialized = True
+
+    @classmethod
+    def get_instance(cls) -> LLMManager:
+        """Returns the singleton instance of LLMManager."""
+        return cls()
+
+    def register_chat(self, name: str, chat_key: str, **chat_kwargs: Any) -> None:
+        """
+        Registers a ChatLLM instance under a unique name.
+
+        Args:
+            name: Unique identifier for the ChatLLM instance.
+            chat_key: Key for the ChatLLM implementation in LLMFactory.
+            **chat_kwargs: Additional parameters for instantiation.
+        """
+        self.chat_llms[name] = LLMFactory.create_chat_llm(chat_key, **chat_kwargs)
+
+    def register_embedding(self, name: str, embedding_key: str, **embedding_kwargs: Any) -> None:
+        """
+        Registers an EmbeddingsLLM instance under a unique name.
+
+        Args:
+            name: Unique identifier for the EmbeddingsLLM instance.
+            embedding_key: Key for the EmbeddingsLLM implementation in LLMFactory.
+            **embedding_kwargs: Additional parameters for instantiation.
+        """
+        self.embedding_llms[name] = LLMFactory.create_embedding_llm(embedding_key, **embedding_kwargs)
+
+    def get_chat_llm(self, name: str) -> ChatLLM:
+        """
+        Retrieves the ChatLLM instance registered under the given name.
+
+        Raises:
+            ValueError: If no ChatLLM is registered under the name.
+        """
+        if name not in self.chat_llms:
+            raise ValueError(f"No ChatLLM registered under name '{name}'.")
+        return self.chat_llms[name]
+
+    def get_embedding_llm(self, name: str) -> EmbeddingLLM:
+        """
+        Retrieves the EmbeddingsLLM instance registered under the given name.
+
+        Raises:
+            ValueError: If no EmbeddingsLLM is registered under the name.
+ """ + if name not in self.embedding_llms: + raise ValueError(f"No EmbeddingsLLM registered under name '{name}'.") + return self.embedding_llms[name] + + def remove_chat(self, name: str) -> None: + """Removes the ChatLLM instance registered under the given name.""" + self.chat_llms.pop(name, None) + + def remove_embedding(self, name: str) -> None: + """Removes the EmbeddingsLLM instance registered under the given name.""" + self.embedding_llms.pop(name, None) + + def list_chat_llms(self) -> dict[str, ChatLLM]: + """Returns a copy of all registered ChatLLM instances.""" + return {k: v for k, v in self.chat_llms.items()} + + def list_embedding_llms(self) -> dict[str, EmbeddingLLM]: + """Returns a copy of all registered EmbeddingsLLM instances.""" + return {k: v for k, v in self.embedding_llms.items()} \ No newline at end of file diff --git a/graphrag/llm/protocols/__init__.py b/graphrag/llm/protocols/__init__.py new file mode 100644 index 00000000..72602bbb --- /dev/null +++ b/graphrag/llm/protocols/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +from .chat import ChatLLM +from .embedding import EmbeddingLLM + +__all__ = ["ChatLLM", "EmbeddingLLM"] diff --git a/graphrag/llm/protocols/chat.py b/graphrag/llm/protocols/chat.py new file mode 100644 index 00000000..d527dd08 --- /dev/null +++ b/graphrag/llm/protocols/chat.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +from __future__ import annotations +from typing import Protocol, List, Any + +class ChatLLM(Protocol): + def chat(self, prompt: str, **kwargs: Any) -> str: + """ + Generate a chat response based on the provided prompt. + + Args: + prompt: The text prompt to generate a response for. + **kwargs: Additional keyword arguments (e.g., temperature, max_tokens). + + Returns: + A string response generated by the LLM. + """ + ... 
\ No newline at end of file
diff --git a/graphrag/llm/protocols/embedding.py b/graphrag/llm/protocols/embedding.py
new file mode 100644
index 00000000..57351069
--- /dev/null
+++ b/graphrag/llm/protocols/embedding.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+from __future__ import annotations
+from typing import Protocol, Any
+
+class EmbeddingLLM(Protocol):
+    def embed(self, text: str, **kwargs: Any) -> list[float]:
+        """
+        Generate an embedding vector for the given text.
+
+        Args:
+            text: The text to generate an embedding for.
+            **kwargs: Additional keyword arguments (e.g., model parameters).
+
+        Returns:
+            A list of floats representing the embedding vector.
+        """
+        ...
\ No newline at end of file