Merge pull request #14 from WyszukiwarkaPublikacji/add-vector-db

Add a vector database for chemical embeddings. It may be replaced in the future, but for now we need it to keep making progress on our tool.
This commit is contained in:
Bartosz Trojan 2025-02-21 16:01:14 +01:00 committed by GitHub
commit 8085608d62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 5 deletions

3
aslite/config.py Normal file
View File

@ -0,0 +1,3 @@
# Dimension of the binary chemical embedding vectors (used as `dim` for the
# BINARY_VECTOR field of the chemical_embeddings collection).
chemical_embedding_size: int = 2048
# Consistency level passed to create_collection for the embeddings collection.
# In production this should probably be "Eventually", but "Strong" might make testing easier.
consistency_level = "Strong"
# Index type used for the chemical_embedding vector field.
chemical_index_type = "BIN_FLAT"

View File

@ -8,6 +8,8 @@ import os
import sqlite3, zlib, pickle, tempfile
from sqlitedict import SqliteDict
from contextlib import contextmanager
from pymilvus import MilvusClient, DataType
from aslite import config
# -----------------------------------------------------------------------------
# global configuration
@ -103,7 +105,7 @@ flag='r': open for read-only
PAPERS_DB_FILE = os.path.join(DATA_DIR, 'papers.db')
# stores account-relevant info, like which tags exist for which papers
DICT_DB_FILE = os.path.join(DATA_DIR, 'dict.db')
# path handed to MilvusClient by get_embeddings_db() for the chemical embeddings store
# NOTE: once we set it up with docker it will probably need to be a standalone db
EMBEDDING_DB_FILE = os.path.join(DATA_DIR, 'embeddings.db')
def get_papers_db(flag='r', autocommit=True):
assert flag in ['r', 'c']
pdb = CompressedSqliteDict(PAPERS_DB_FILE, tablename='papers', flag=flag, autocommit=autocommit)
@ -128,7 +130,33 @@ def get_email_db(flag='r', autocommit=True):
assert flag in ['r', 'c']
edb = SqliteDict(DICT_DB_FILE, tablename='email', flag=flag, autocommit=autocommit)
return edb
def setup_chemical_embeddings_collection(client: MilvusClient):
    """Create the ``chemical_embeddings`` collection on *client*.

    The collection stores one row per chemical structure with:

    * ``id`` -- INT64 primary key, generated automatically on insert
    * ``chemical_embedding`` -- binary vector of ``config.chemical_embedding_size`` bits
    * ``paper_id`` -- INT64 id of the paper the structure belongs to
    * ``category`` -- VARCHAR (max 127 chars)
    * ``SMILES`` -- VARCHAR (max 65535 chars)
    * ``tags`` -- array of up to 100 VARCHAR elements (max 127 chars each)

    The embedding field is indexed with ``config.chemical_index_type`` using
    the JACCARD metric, and the collection is created with
    ``config.consistency_level``.

    Args:
        client: the MilvusClient on which the collection is created.
    """
    # The primary key below is declared with auto_id=True (callers insert rows
    # without an "id"), so the schema-level flag has to say the same thing;
    # previously it was auto_id=False, contradicting the field definition.
    schema = MilvusClient.create_schema(auto_id=True)
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
    schema.add_field(
        field_name="chemical_embedding",
        datatype=DataType.BINARY_VECTOR,
        dim=config.chemical_embedding_size,
    )
    schema.add_field(field_name="paper_id", datatype=DataType.INT64)
    schema.add_field(field_name="category", datatype=DataType.VARCHAR, max_length=127)
    schema.add_field(field_name="SMILES", datatype=DataType.VARCHAR, max_length=65535)
    schema.add_field(
        field_name="tags",
        datatype=DataType.ARRAY,
        element_type=DataType.VARCHAR,
        max_capacity=100,
        max_length=127,
    )

    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="chemical_embedding",
        index_type=config.chemical_index_type,
        metric_type="JACCARD",
    )

    client.create_collection(
        collection_name="chemical_embeddings",
        schema=schema,
        index_params=index_params,
        consistency_level=config.consistency_level,
    )
def get_embeddings_db():
    """Open the embeddings database and return the MilvusClient for it.

    The ``chemical_embeddings`` collection is created first if the client
    reports that it does not exist yet.
    """
    milvus = MilvusClient(EMBEDDING_DB_FILE)
    collection_exists = milvus.has_collection("chemical_embeddings")
    if collection_exists:
        return milvus
    # first use of this database: build the collection before handing it out
    setup_chemical_embeddings_collection(milvus)
    return milvus
# -----------------------------------------------------------------------------
"""
our "feature store" is currently just a pickle file, may want to consider hdf5 in the future

View File

@ -5,10 +5,11 @@ itsdangerous==2.2.0
Jinja2==3.1.5
joblib==1.4.2
MarkupSafe==3.0.2
numpy==1.21.4
scikit-learn==1.0.1
scipy==1.10.1
sgmllib3k==1.0.0
numpy
scikit-learn
Scipy
Sgmllib3k==1.0.0
sqlitedict==1.7.0
threadpoolctl==3.5.0
Werkzeug==2.3.8
pymilvus==2.5.4

61
test_db.py Normal file
View File

@ -0,0 +1,61 @@
from aslite.db import *
def convert_bool_list_to_bytes(bool_list):
    """Pack a sequence of bits into a ``bytes`` object, eight bits per byte.

    Element ``i`` lands in byte ``i // 8`` at bit position ``i % 8``, i.e. the
    first element of every group of eight becomes the least significant bit.
    Only elements that compare equal to 1 (``True`` or ``1``) set a bit.

    Raises:
        ValueError: if the number of elements is not a multiple of 8.
    """
    total_bits = len(bool_list)
    if total_bits % 8:
        raise ValueError("The length of a boolean list must be a multiple of 8")

    packed = bytearray(total_bits // 8)
    for position, flag in enumerate(bool_list):
        if flag != 1:
            continue
        byte_index, bit_offset = divmod(position, 8)
        packed[byte_index] |= 1 << bit_offset
    return bytes(packed)
def test_chemical_embeddings_db():  # temporary test code, should not be run on a full db
    """Smoke-test the chemical embeddings collection.

    Inserts 1000 rows with random binary embeddings into the
    ``chemical_embeddings`` collection, then runs one JACCARD search per tag
    ("test", "test1", "test2") with a random query vector, printing each result.

    Every row carries the tag "test" plus either "test1" or "test2". The
    second tag was previously written as "test 1" (with a space), so the
    "test1" search could never match anything.
    """
    import random

    def random_bits():
        # one random bit per embedding dimension
        return [bool(random.randint(0, 1)) for _ in range(config.chemical_embedding_size)]

    embedding_db: MilvusClient = get_embeddings_db()
    print(embedding_db.describe_collection("chemical_embeddings"))

    data = [
        {
            "chemical_embedding": convert_bool_list_to_bytes(random_bits()),
            "tags": ["test", "test1" if random.randint(0, 1) == 0 else "test2"],
            "category": "chemistry?",
            "paper_id": random.randint(1, 1000),
            "SMILES": "CN=C=O",
        }
        for _ in range(1000)
    ]
    res = embedding_db.insert(collection_name="chemical_embeddings", data=data)
    print(res)

    random_vector_q = random_bits()
    print(random_vector_q)
    # encode the query once and reuse it for every search
    query = convert_bool_list_to_bytes(random_vector_q)

    for tag in ("test", "test1", "test2"):
        res = embedding_db.search(
            collection_name="chemical_embeddings",
            data=[query],
            limit=5,
            anns_field="chemical_embedding",
            filter=f'ARRAY_CONTAINS(tags, "{tag}")',
            search_params={"metric_type": "JACCARD"},
        )
        print(res)
# Run the smoke test only when this file is executed as a script, so that
# merely importing the module (e.g. during test collection) does not insert
# rows into the embeddings database.
if __name__ == "__main__":
    test_chemical_embeddings_db()