prep storage class for silicon indexing

Kenny Zhang · 2025-05-06 16:31:53 -04:00
parent 4cd3186135
commit 26c4a70ce4


@@ -14,7 +14,7 @@ from typing import Any
 import numpy as np
 import pandas as pd
 import pyodbc
-from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
+from azure.identity import DefaultAzureCredential
 from pyodbc import Connection
 
 from graphrag.logger.base import ProgressLogger
@@ -52,10 +52,7 @@ class SQLServerPipelineStorage(PipelineStorage):
         # Use password-less authentication for the db server
         self._local_connection_string = f"Driver={{ODBC Driver 18 for SQL Server}};Server=tcp:{database_server_name}.database.windows.net,1433;Database={database_name};Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;LongAsMax=yes"
-        if client_id:
-            credential = ManagedIdentityCredential(client_id=client_id)
-        else:
-            credential = DefaultAzureCredential()
+        credential = DefaultAzureCredential()
 
         # These connection options are recommended by Microsoft
         token_bytes = credential.get_token(
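Dropping the `if client_id` branch works because `DefaultAzureCredential` already tries managed identity as part of its credential chain. For reference, the password-less flow this class builds on typically looks like the sketch below; `connect_with_token` is a hypothetical helper, and the scope and token packing follow the pattern Microsoft documents for pyodbc with Entra ID, since this file's own `get_token(...)` call is truncated in the hunk above.

```python
import struct

import pyodbc
from azure.identity import DefaultAzureCredential

# pyodbc connection attribute for passing an Entra ID access token
SQL_COPT_SS_ACCESS_TOKEN = 1256


def connect_with_token(connection_string: str) -> pyodbc.Connection:
    # DefaultAzureCredential walks a chain: env vars, managed identity, CLI, ...
    credential = DefaultAzureCredential()
    token = credential.get_token("https://database.windows.net/.default")
    # The ODBC driver expects a length-prefixed UTF-16-LE byte string
    raw = token.token.encode("utf-16-le")
    token_struct = struct.pack(f"<I{len(raw)}s", len(raw), raw)
    return pyodbc.connect(
        connection_string, attrs_before={SQL_COPT_SS_ACCESS_TOKEN: token_struct}
    )
```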
@@ -272,12 +269,12 @@ class SQLServerPipelineStorage(PipelineStorage):
         try:
             # Only store parquet file data
             if isinstance(value, bytes) and key.endswith(".parquet"):
-                cursor = self._connection.cursor()
                 table_name = key.split(".")[0]
                 data_frame = pd.read_parquet(BytesIO(value))
                 # Automatically build CREATE TABLE statement based on DataFrame columns
                 if self._autogenerate_tables:
+                    cursor = self._connection.cursor()
                     columns = []
                     for col_name, dtype in zip(
                         data_frame.columns, data_frame.dtypes, strict=False
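The hunk above only shows the head of the column loop, so the dtype handling is out of frame. A hedged sketch of what a `CREATE TABLE` autogenerated this way can look like; the type map and the `build_create_table_sql` helper are illustrative assumptions, not this file's actual logic:

```python
import pandas as pd

# Illustrative dtype-to-SQL-type map; unknown dtypes fall back to text
_SQL_TYPES = {
    "int64": "BIGINT",
    "float64": "FLOAT",
    "bool": "BIT",
    "datetime64[ns]": "DATETIME2",
}


def build_create_table_sql(table_name: str, data_frame: pd.DataFrame) -> str:
    columns = []
    for col_name, dtype in zip(
        data_frame.columns, data_frame.dtypes, strict=False
    ):
        sql_type = _SQL_TYPES.get(str(dtype), "NVARCHAR(MAX)")
        columns.append(f"[{col_name}] {sql_type}")
    return f"CREATE TABLE [{table_name}] ({', '.join(columns)})"  # noqa: S608
```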
@@ -312,20 +309,51 @@ class SQLServerPipelineStorage(PipelineStorage):
                         columns,
                     )
                     cursor.execute(create_table_sql)
+                    self._connection.commit()
                 # Insert parquet data into SQL server
                 num_rows = len(data_frame)
-                num_loaded = 0
-                for _, row in data_frame.iterrows():
-                    try:
-                        placeholders = ", ".join([
-                            "?" for _ in range(len(data_frame.columns))
-                        ])
-                        column_names = ", ".join([
-                            f"[{col}]" for col in data_frame.columns
-                        ])
-                        insert_sql = f"INSERT INTO [{table_name}] ({column_names}) VALUES ({placeholders})"  # noqa: S608
+                placeholders = ", ".join([
+                    "?" for _ in range(len(data_frame.columns))
+                ])
+                column_names = ", ".join([
+                    f"[{col}]" for col in data_frame.columns
+                ])
+                insert_sql = f"INSERT INTO [{table_name}] ({column_names}) VALUES ({placeholders})"  # noqa: S608
+                log.info(
+                    "Inserting %s rows into table %s", num_rows, table_name
+                )
+                # Use a row-wise cursor loop for small tables
+                if num_rows < 1000:
+                    cursor = self._connection.cursor()
+                    for _, row in data_frame.iterrows():
+                        try:
+                            # Handle various value types, converting complex types to strings
+                            values = []
+                            for val in row:
+                                if isinstance(val, np.ndarray):
+                                    values.append(json.dumps(val.tolist()))
+                                elif isinstance(val, list):
+                                    values.append(json.dumps(val))
+                                elif pd.isna(val):
+                                    values.append(None)
+                                else:
+                                    values.append(val)
+                            cursor.execute(insert_sql, values)
+                        except Exception:  # noqa: BLE001
+                            log.info(
+                                "Error inserting row %s into table: %s, skipping...",
+                                row.to_dict(),
+                                table_name,
+                            )
+                            continue
+                # For large tables, use a bulk insert strategy
+                else:
+                    cursor = self._connection.cursor()
+                    cursor.fast_executemany = True
+                    bulk_values = []
+                    for row in data_frame.itertuples(index=False):
                         # Handle various value types, converting complex types to strings
                         values = []
                         for val in row:
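The conversion loop that begins here is duplicated between the row-wise and bulk paths, so it is easiest to read as one rule; a standalone restatement follows (`normalize_value` is a hypothetical name). The check order matters: the `np.ndarray` and `list` branches must run before `pd.isna`, because `pd.isna` on an array returns an elementwise mask rather than a single bool.

```python
import json
from typing import Any

import numpy as np
import pandas as pd


def normalize_value(val: Any) -> Any:
    """Coerce one DataFrame cell into a pyodbc-bindable parameter."""
    if isinstance(val, np.ndarray):
        # arrays (e.g. embedding vectors) are stored as JSON text
        return json.dumps(val.tolist())
    if isinstance(val, list):
        return json.dumps(val)
    if pd.isna(val):
        # NaN/NaT map to SQL NULL
        return None
    return val
```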
@@ -337,22 +365,8 @@ class SQLServerPipelineStorage(PipelineStorage):
                                 values.append(None)
                             else:
                                 values.append(val)
-                        cursor.execute(insert_sql, values)
-                        num_loaded += 1
-                        if num_loaded % 1000 == 0:
-                            log.info(
-                                "Successfully inserted row %s out of %s into table %s",
-                                num_loaded,
-                                num_rows,
-                                table_name,
-                            )
-                    except Exception:
-                        log.exception(
-                            "Error inserting row %s into table: %s",
-                            row.to_dict(),
-                            table_name,
-                        )
-                        continue
+                        bulk_values.append(values)
+                    cursor.executemany(insert_sql, bulk_values)
                 self._connection.commit()
                 log.info("Successfully stored %s in SQL Server", key)
             else:
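On the 1,000-row split: with `fast_executemany = True`, pyodbc binds every parameter set client-side and ships them to the server in large batches, replacing one round trip per row with a few per batch. The trade-offs are that all rows sit in memory at once and that a single bad row fails the whole `executemany`, which is presumably why the forgiving row-by-row path is kept for small tables. A minimal sketch of the bulk path under those assumptions (connection setup not shown):

```python
import pyodbc


def bulk_insert(
    connection: pyodbc.Connection, insert_sql: str, rows: list[list]
) -> None:
    cursor = connection.cursor()
    # bind all parameter sets client-side and send them in batches
    cursor.fast_executemany = True
    cursor.executemany(insert_sql, rows)
    connection.commit()
```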
@@ -361,8 +375,8 @@ class SQLServerPipelineStorage(PipelineStorage):
                     key,
                 )
         except Exception:
-            self._connection.rollback()
-            log.exception("Error writing data %s: %s", key)
+            self._connection.commit()
+            log.exception("Error writing data %s", key)
 
     async def has(self, key: str) -> bool:
         """Check if a table/file exists in SQL Server.