graphrag/graphrag/index/run/derive_from_rows.py
Nathan Evans 7ec9ef0261
Refactor callbacks (#1583)
* Unify Workflow and Verb callbacks interfaces

* Semver

* Fix storage class instantiation (#1582)

---------

Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
2025-01-06 10:58:59 -08:00

159 lines
4.7 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Apply a generic transform function to each row in a table."""
import asyncio
import inspect
import logging
import traceback
from collections.abc import Awaitable, Callable, Coroutine, Hashable
from typing import Any, TypeVar, cast
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.logger.progress import progress_ticker
# Generic placeholder for the value a per-row transform produces.
ItemType = TypeVar("ItemType")
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
class ParallelizationError(ValueError):
    """Raised when one or more rows fail during a parallel row transformation.

    Args:
        num_errors: How many rows failed; embedded in the exception message.
    """

    def __init__(self, num_errors: int):
        message = (
            f"{num_errors} Errors occurred while running parallel transformation, "
            "could not complete!"
        )
        super().__init__(message)
async def derive_from_rows(
    input: pd.DataFrame,
    transform: Callable[[pd.Series], Awaitable[ItemType]],
    callbacks: WorkflowCallbacks,
    num_threads: int = 4,
    async_type: AsyncType = AsyncType.AsyncIO,
) -> list[ItemType | None]:
    """Run ``transform`` over every row of ``input`` with the chosen scheduler.

    ``async_type`` selects the strategy: ``AsyncType.AsyncIO`` uses
    ``derive_from_rows_asyncio`` and ``AsyncType.Threaded`` uses
    ``derive_from_rows_asyncio_threads``. Both receive ``num_threads`` as their
    concurrency limit. Per-row failures are reported through ``callbacks`` and
    then surfaced as a single ``ParallelizationError``.

    Raises:
        ValueError: if ``async_type`` is neither ``AsyncIO`` nor ``Threaded``.
    """
    if async_type == AsyncType.AsyncIO:
        strategy = derive_from_rows_asyncio
    elif async_type == AsyncType.Threaded:
        strategy = derive_from_rows_asyncio_threads
    else:
        msg = f"Unsupported scheduling type {async_type}"
        raise ValueError(msg)
    return await strategy(input, transform, callbacks, num_threads)


# Concrete scheduling strategies used by ``derive_from_rows`` follow.
async def derive_from_rows_asyncio_threads(
    input: pd.DataFrame,
    transform: Callable[[pd.Series], Awaitable[ItemType]],
    callbacks: WorkflowCallbacks,
    num_threads: int | None = 4,
) -> list[ItemType | None]:
    """
    Apply ``transform`` to each row, dispatching every row via ``asyncio.to_thread``.

    At most ``num_threads`` rows are in flight at once. The limit falls back to 4
    when ``num_threads`` is ``None`` or 0. This is useful for IO bound operations.

    NOTE(review): the worker passed in as ``execute`` is an ``async def``, so
    ``asyncio.to_thread(execute, row)`` only *creates* its coroutine inside the
    worker thread. That coroutine is then awaited on the calling event loop,
    which means the transform itself does not appear to run off-loop. Confirm
    whether true per-thread execution is intended before relying on this mode
    for CPU-bound or blocking transforms.
    """
    limiter = asyncio.Semaphore(num_threads or 4)

    async def gather(execute: ExecuteFn[ItemType]) -> list[ItemType | None]:
        async def run_one(pending: Coroutine) -> ItemType | None:
            async with limiter:
                # The thread hands back whatever execute(row) returned, which
                # is itself a coroutine, so it has to be awaited a second time.
                inner = await pending
                return await inner

        # Coroutine objects only; nothing starts until run_one awaits them.
        pending_calls = [asyncio.to_thread(execute, row) for row in input.iterrows()]
        return await asyncio.gather(*(run_one(pending) for pending in pending_calls))

    return await _derive_from_rows_base(input, transform, callbacks, gather)
async def derive_from_rows_asyncio(
    input: pd.DataFrame,
    transform: Callable[[pd.Series], Awaitable[ItemType]],
    callbacks: WorkflowCallbacks,
    num_threads: int = 4,
) -> list[ItemType | None]:
    """
    Apply ``transform`` to each row as concurrent tasks on the running event loop.

    Concurrency is capped by a semaphore of size ``num_threads``, which falls
    back to 4 when it is falsy. This is useful for IO bound operations.
    """
    limiter = asyncio.Semaphore(num_threads or 4)

    async def gather(execute: ExecuteFn[ItemType]) -> list[ItemType | None]:
        async def guarded(
            row: tuple[Hashable, pd.Series],
        ) -> ItemType | None:
            async with limiter:
                return await execute(row)

        # Every row is scheduled immediately; the semaphore throttles how many
        # actually run at the same time.
        scheduled = [asyncio.create_task(guarded(row)) for row in input.iterrows()]
        return await asyncio.gather(*scheduled)

    return await _derive_from_rows_base(input, transform, callbacks, gather)
# ``ItemType`` is the module-level TypeVar declared near the top of this file.
# It is reused here rather than declared a second time, so the public
# signatures and these aliases share a single TypeVar object.

# Per-row worker handed to a gather strategy. It takes one ``(index, row)``
# pair as yielded by ``DataFrame.iterrows`` and resolves to the transformed
# value. The worker built by ``_derive_from_rows_base`` resolves to ``None``
# for a row whose transform raised.
ExecuteFn = Callable[[tuple[Hashable, pd.Series]], Awaitable[ItemType | None]]
# Scheduling strategy: applies an ExecuteFn to every row and resolves to the
# list of per-row results.
GatherFn = Callable[[ExecuteFn[ItemType]], Awaitable[list[ItemType | None]]]
async def _derive_from_rows_base(
    input: pd.DataFrame,
    transform: Callable[[pd.Series], Awaitable[ItemType]],
    callbacks: WorkflowCallbacks,
    gather: GatherFn[ItemType],
) -> list[ItemType | None]:
    """
    Run ``transform`` over every row via ``gather`` and collect the results.

    ``gather`` decides the scheduling strategy. This helper supplies the per-row
    ``execute`` wrapper, which applies ``transform``, records any failure, and
    advances progress.

    Args:
        input: Rows to process. Each row's Series is passed to ``transform``.
        transform: Per-row function. It may return the value directly or any
            awaitable (coroutine, Task, Future, ...) that resolves to it.
        callbacks: Receives progress ticks and one ``error`` call per failed row.
        gather: Strategy that invokes ``execute`` for each ``(index, row)`` pair
            and resolves to the collected results (presumably in row order;
            the bundled strategies use ``asyncio.gather``).

    Returns:
        The list produced by ``gather``. Any row failure raises instead, so a
        returned list contains ``None`` only where ``transform`` itself
        produced ``None``.

    Raises:
        ParallelizationError: if the transform raised for at least one row.
            This is raised only after every row has run and every error has
            been reported.
    """
    tick = progress_ticker(callbacks.progress, num_total=len(input))
    errors: list[tuple[BaseException, str]] = []

    async def execute(row: tuple[Any, pd.Series]) -> ItemType | None:
        try:
            result = transform(row[1])
            # Await any awaitable (coroutine, Task, Future, custom __await__),
            # not only coroutines, so such results are resolved rather than
            # returned as-is. Plain synchronous values pass through unchanged.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:  # noqa: BLE001
            # Record the failure and keep going so the remaining rows still
            # run. All failures are reported together after gather completes.
            errors.append((e, traceback.format_exc()))
            return None
        else:
            return cast("ItemType", result)
        finally:
            tick(1)

    result = await gather(execute)
    tick.done()

    for error, stack in errors:
        callbacks.error("parallel transformation error", error, stack)
    if errors:
        raise ParallelizationError(len(errors))
    return result