mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 00:57:23 +08:00)
prep storage class for silicon indexing

parent 4cd3186135
commit 26c4a70ce4
@@ -14,7 +14,7 @@ from typing import Any
 import numpy as np
 import pandas as pd
 import pyodbc
-from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
+from azure.identity import DefaultAzureCredential
 from pyodbc import Connection
 
 from graphrag.logger.base import ProgressLogger
@@ -52,10 +52,7 @@ class SQLServerPipelineStorage(PipelineStorage):
 
         # Use password-less authentication for the db server
         self._local_connection_string = f"Driver={{ODBC Driver 18 for SQL Server}};Server=tcp:{database_server_name}.database.windows.net,1433;Database={database_name};Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;LongAsMax=yes"
-        if client_id:
-            credential = ManagedIdentityCredential(client_id=client_id)
-        else:
-            credential = DefaultAzureCredential()
+        credential = DefaultAzureCredential()
 
         # These connection options are recommended by Microsoft
         token_bytes = credential.get_token(
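Note: the get_token( call is cut off at the hunk boundary above. For context, the usual password-less pattern is to request an Entra ID token scoped to https://database.windows.net/.default and hand it to pyodbc through the SQL_COPT_SS_ACCESS_TOKEN pre-connection attribute. A minimal standalone sketch of that pattern, not taken verbatim from this file; the <server> and <database> placeholders are illustrative:

import struct

import pyodbc
from azure.identity import DefaultAzureCredential

# Pre-connection attribute the ODBC driver reads the Entra ID access token from.
SQL_COPT_SS_ACCESS_TOKEN = 1256

connection_string = (
    "Driver={ODBC Driver 18 for SQL Server};"
    "Server=tcp:<server>.database.windows.net,1433;"
    "Database=<database>;Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30"
)

credential = DefaultAzureCredential()
token = credential.get_token("https://database.windows.net/.default")

# The driver expects the token as UTF-16-LE bytes, length-prefixed as a little-endian uint32.
raw_token = token.token.encode("utf-16-le")
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)

connection = pyodbc.connect(
    connection_string,
    attrs_before={SQL_COPT_SS_ACCESS_TOKEN: token_struct},
)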
@@ -272,12 +269,12 @@ class SQLServerPipelineStorage(PipelineStorage):
         try:
             # Only store parquet file data
             if isinstance(value, bytes) and key.endswith(".parquet"):
-                cursor = self._connection.cursor()
                 table_name = key.split(".")[0]
                 data_frame = pd.read_parquet(BytesIO(value))
 
                 # Automatically build CREATE TABLE statement based on DataFrame columns
                 if self._autogenerate_tables:
+                    cursor = self._connection.cursor()
                     columns = []
                     for col_name, dtype in zip(
                         data_frame.columns, data_frame.dtypes, strict=False
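Note: the loop over data_frame.dtypes above feeds the autogenerated CREATE TABLE statement executed in the next hunk. A hedged sketch of that idea as a standalone helper; the dtype-to-SQL-type mapping and the build_create_table name are illustrative assumptions, not the exact mapping used in this file:

import pandas as pd


def build_create_table(table_name: str, data_frame: pd.DataFrame) -> str:
    """Build a CREATE TABLE statement from a DataFrame's columns and dtypes."""
    type_map = {
        "int64": "BIGINT",
        "float64": "FLOAT",
        "bool": "BIT",
        "datetime64[ns]": "DATETIME2",
    }
    columns = []
    for col_name, dtype in zip(data_frame.columns, data_frame.dtypes, strict=False):
        # Fall back to text for object/list columns.
        sql_type = type_map.get(str(dtype), "NVARCHAR(MAX)")
        columns.append(f"[{col_name}] {sql_type}")
    return f"CREATE TABLE [{table_name}] ({', '.join(columns)})"


# Example: CREATE TABLE [entities] ([id] NVARCHAR(MAX), [rank] BIGINT)
print(build_create_table("entities", pd.DataFrame({"id": ["a"], "rank": [1]})))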
@@ -312,20 +309,51 @@ class SQLServerPipelineStorage(PipelineStorage):
                         columns,
                     )
                     cursor.execute(create_table_sql)
                     self._connection.commit()
 
                 # Insert parquet data into SQL server
                 num_rows = len(data_frame)
-                num_loaded = 0
-                for _, row in data_frame.iterrows():
-                    try:
-                        placeholders = ", ".join([
-                            "?" for _ in range(len(data_frame.columns))
-                        ])
-                        column_names = ", ".join([
-                            f"[{col}]" for col in data_frame.columns
-                        ])
-                        insert_sql = f"INSERT INTO [{table_name}] ({column_names}) VALUES ({placeholders})"  # noqa: S608
+                placeholders = ", ".join([
+                    "?" for _ in range(len(data_frame.columns))
+                ])
+                column_names = ", ".join([
+                    f"[{col}]" for col in data_frame.columns
+                ])
+                insert_sql = f"INSERT INTO [{table_name}] ({column_names}) VALUES ({placeholders})"  # noqa: S608
+                log.info(
+                    "Inserting %s rows into table %s", num_rows, table_name
+                )
+
+                # Use a row-wise cursor loop for small tables
+                if num_rows < 1000:
+                    cursor = self._connection.cursor()
+                    for _, row in data_frame.iterrows():
+                        try:
+                            # Handle various value types, converting complex types to strings
+                            values = []
+                            for val in row:
+                                if isinstance(val, np.ndarray):
+                                    values.append(json.dumps(val.tolist()))
+                                elif isinstance(val, list):
+                                    values.append(json.dumps(val))
+                                elif pd.isna(val):
+                                    values.append(None)
+                                else:
+                                    values.append(val)
+                            cursor.execute(insert_sql, values)
+                        except Exception:  # noqa: BLE001
+                            log.info(
+                                "Error inserting row %s into table: %s, skipping...",
+                                row.to_dict(),
+                                table_name,
+                            )
+                            continue
+                # For large tables, use a bulk insert strategy
+                else:
+                    cursor = self._connection.cursor()
+                    cursor.fast_executemany = True
+                    bulk_values = []
+                    for row in data_frame.itertuples(index=False):
                         # Handle various value types, converting complex types to strings
                         values = []
                         for val in row:
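Note: the new code splits inserts into a row-wise cursor loop for small tables and a fast_executemany batch for large ones, so large parquet files go over in one round trip per batch instead of one per row. A condensed, self-contained sketch of the bulk path, assuming the target table already exists; value handling mirrors the loop above (arrays and lists serialized to JSON, NaN becomes NULL):

import json

import numpy as np
import pandas as pd
import pyodbc


def bulk_insert(connection: pyodbc.Connection, table_name: str, data_frame: pd.DataFrame) -> None:
    """Insert all DataFrame rows with a single fast_executemany call."""
    placeholders = ", ".join("?" for _ in data_frame.columns)
    column_names = ", ".join(f"[{col}]" for col in data_frame.columns)
    insert_sql = f"INSERT INTO [{table_name}] ({column_names}) VALUES ({placeholders})"  # noqa: S608

    cursor = connection.cursor()
    cursor.fast_executemany = True

    bulk_values = []
    for row in data_frame.itertuples(index=False):
        values = []
        for val in row:
            if isinstance(val, np.ndarray):
                values.append(json.dumps(val.tolist()))  # serialize arrays as JSON text
            elif isinstance(val, list):
                values.append(json.dumps(val))
            elif pd.isna(val):
                values.append(None)  # NaN/NaT become SQL NULL
            else:
                values.append(val)
        bulk_values.append(values)

    cursor.executemany(insert_sql, bulk_values)
    connection.commit()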
@@ -337,22 +365,8 @@ class SQLServerPipelineStorage(PipelineStorage):
                                 values.append(None)
                             else:
                                 values.append(val)
-                        cursor.execute(insert_sql, values)
-                        num_loaded += 1
-                        if num_loaded % 1000 == 0:
-                            log.info(
-                                "Successfully inserted row %s out of %s into table %s",
-                                num_loaded,
-                                num_rows,
-                                table_name,
-                            )
-                    except Exception:
-                        log.exception(
-                            "Error inserting row %s into table: %s",
-                            row.to_dict(),
-                            table_name,
-                        )
-                        continue
+                        bulk_values.append(values)
+                    cursor.executemany(insert_sql, bulk_values)
                 self._connection.commit()
                 log.info("Successfully stored %s in SQL Server", key)
             else:
@@ -361,8 +375,8 @@ class SQLServerPipelineStorage(PipelineStorage):
                     key,
                 )
         except Exception:
-            self._connection.rollback()
-            log.exception("Error writing data %s: %s", key)
+            self._connection.commit()
+            log.exception("Error writing data %s: %s.", key)
 
     async def has(self, key: str) -> bool:
         """Check if a table/file exists in SQL Server.
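Note: the has() method shown as trailing context is cut off by the hunk boundary. A common way to implement the table-existence check its docstring describes is an INFORMATION_SCHEMA lookup; a hedged sketch that reuses the key.split(".")[0] key-to-table convention from set(), where the helper name and exact query are assumptions rather than this file's implementation:

import pyodbc


def table_exists(connection: pyodbc.Connection, key: str) -> bool:
    """Return True if the table backing this storage key exists in SQL Server."""
    table_name = key.split(".")[0]  # same key-to-table convention as set()
    cursor = connection.cursor()
    cursor.execute(
        "SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = ?",
        (table_name,),
    )
    return cursor.fetchone() is not None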