From 940f60303084d175fa6f3dedad86edafd9a96bc1 Mon Sep 17 00:00:00 2001
From: kachim2
Date: Sun, 16 Feb 2025 21:26:23 +0100
Subject: [PATCH] Add a vector db

---
Review notes (text between the "---" line and the first "diff --git" is
commentary and is not applied):
- aslite/db.py: the schema-level auto_id now agrees with the auto-generated
  "id" primary key; docstrings added; the blank line before get_papers_db is
  kept and the trailing blank lines are dropped.
- requirements.txt: scipy / sgmllib3k keep their lowercase names.
- test_db.py: rows are tagged "test1" (was "test 1") so the "test1" filter can
  match; explicit imports replace the wildcard import; the test runs only
  when the file is executed as a script.
- The "index" hashes below are those of the original patch revision.

 aslite/config.py |  3 ++
 aslite/db.py     | 44 +++++++++++++++++++++++++++++++
 requirements.txt |  7 +++---
 test_db.py       | 68 ++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 119 insertions(+), 3 deletions(-)
 create mode 100644 aslite/config.py
 create mode 100644 test_db.py

diff --git a/aslite/config.py b/aslite/config.py
new file mode 100644
index 0000000..2cdce8f
--- /dev/null
+++ b/aslite/config.py
@@ -0,0 +1,3 @@
+chemical_embedding_size: int = 2048
+consistency_level = "Strong" # in production should probably be Eventually, but Strong might make testing easier
+chemical_index_type = "BIN_FLAT"
diff --git a/aslite/db.py b/aslite/db.py
index 55f1916..3b4551a 100644
--- a/aslite/db.py
+++ b/aslite/db.py
@@ -8,6 +8,8 @@ import os
 import sqlite3, zlib, pickle, tempfile
 from sqlitedict import SqliteDict
 from contextlib import contextmanager
+from pymilvus import MilvusClient, DataType
+from aslite import config
 
 # -----------------------------------------------------------------------------
 # global configuration
@@ -103,7 +105,10 @@ flag='r': open for read-only
 PAPERS_DB_FILE = os.path.join(DATA_DIR, 'papers.db')
 # stores account-relevant info, like which tags exist for which papers
 DICT_DB_FILE = os.path.join(DATA_DIR, 'dict.db')
+# stores the vector collections (chemical embeddings)
+# NOTE: once we set it up with docker it will probably need to be a standalone db
+EMBEDDING_DB_FILE = os.path.join(DATA_DIR, 'embeddings.db')
 
 def get_papers_db(flag='r', autocommit=True):
     assert flag in ['r', 'c']
     pdb = CompressedSqliteDict(PAPERS_DB_FILE, tablename='papers', flag=flag, autocommit=autocommit)
@@ -128,7 +133,46 @@ def get_email_db(flag='r', autocommit=True):
     assert flag in ['r', 'c']
     edb = SqliteDict(DICT_DB_FILE, tablename='email', flag=flag, autocommit=autocommit)
     return edb
+
+def setup_chemical_embeddings_collection(client: MilvusClient):
+    """
+    Create the "chemical_embeddings" collection on the given Milvus client.
+
+    Each row holds one binary embedding of config.chemical_embedding_size bits
+    together with a paper id, a category, a SMILES string and a list of tags.
+    The embedding field is indexed with config.chemical_index_type using the
+    JACCARD metric.
+    """
+    # the primary key is auto-generated, so callers never supply "id";
+    # keep the schema-level flag in agreement with the field-level one
+    schema = MilvusClient.create_schema(auto_id=True)
+    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
+    schema.add_field(field_name="chemical_embedding", datatype=DataType.BINARY_VECTOR, dim=config.chemical_embedding_size)
+    schema.add_field(field_name="paper_id", datatype=DataType.INT64)
+    schema.add_field(field_name="category", datatype=DataType.VARCHAR, max_length=127)
+    schema.add_field(field_name="SMILES", datatype=DataType.VARCHAR, max_length=65535)
+    schema.add_field(field_name="tags", datatype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=100, max_length=127)
+
+    index_params = client.prepare_index_params()
+    index_params.add_index(field_name="chemical_embedding", index_type=config.chemical_index_type, metric_type="JACCARD")
+
+    client.create_collection(
+        collection_name="chemical_embeddings",
+        schema=schema,
+        index_params=index_params,
+        consistency_level=config.consistency_level,
+    )
+
+def get_embeddings_db():
+    """
+    Open the embeddings database and return the Milvus client, creating the
+    "chemical_embeddings" collection first if it does not exist yet.
+    """
+    client = MilvusClient(EMBEDDING_DB_FILE)
+    if not client.has_collection("chemical_embeddings"):
+        setup_chemical_embeddings_collection(client)
+    return client
 
 # -----------------------------------------------------------------------------
 """
 our "feature store" is currently just a pickle file, may want to consider hdf5 in the future
diff --git a/requirements.txt b/requirements.txt
index ae7a992..0d5d4f9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,10 +5,11 @@ itsdangerous==2.2.0
 Jinja2==3.1.5
 joblib==1.4.2
 MarkupSafe==3.0.2
-numpy==1.21.4
-scikit-learn==1.0.1
-scipy==1.10.1
+numpy
+scikit-learn
+scipy
 sgmllib3k==1.0.0
 sqlitedict==1.7.0
 threadpoolctl==3.5.0
 Werkzeug==2.3.8
+pymilvus==2.5.4
diff --git a/test_db.py b/test_db.py
new file mode 100644
index 0000000..a37f5d5
--- /dev/null
+++ b/test_db.py
@@ -0,0 +1,68 @@
+"""
+Temporary smoke test for the chemical embeddings vector db.
+It inserts random rows, so it should not be run on a full db.
+"""
+import random
+
+from pymilvus import MilvusClient
+
+from aslite import config
+from aslite.db import get_embeddings_db
+
+
+def convert_bool_list_to_bytes(bool_list):
+    """
+    Pack a list of booleans (or 0/1 ints) into bytes, 8 entries per byte.
+
+    Entry i is stored in byte i // 8 at bit position i % 8 (least significant
+    bit first). Raises ValueError if the length is not a multiple of 8.
+    """
+    if len(bool_list) % 8 != 0:
+        raise ValueError("The length of a boolean list must be a multiple of 8")
+
+    byte_array = bytearray(len(bool_list) // 8)
+    for i, bit in enumerate(bool_list):
+        if bit == 1:
+            index = i // 8
+            shift = i % 8
+            byte_array[index] |= (1 << shift)
+    return bytes(byte_array)
+
+
+def test_chemical_embeddings_db():  # temporary test code, should not be run on a full db
+    """Insert 1000 random rows, then run one filtered JACCARD search per tag."""
+    embedding_db: MilvusClient = get_embeddings_db()
+    print(embedding_db.describe_collection("chemical_embeddings"))
+    random_embeddings = [[bool(random.randint(0, 1)) for _ in range(config.chemical_embedding_size)] for _ in range(1000)]
+    data = [
+        {
+            "chemical_embedding": convert_bool_list_to_bytes(embedding),
+            # every row carries "test" plus exactly one of "test1" / "test2",
+            # matching the tag filters searched below
+            "tags": ["test", "test1" if random.randint(0, 1) == 0 else "test2"],
+            "category": "chemistry?",
+            "paper_id": random.randint(1, 1000),
+            "SMILES": "CN=C=O",
+        } for embedding in random_embeddings
+    ]
+    res = embedding_db.insert(collection_name="chemical_embeddings", data=data)
+    print(res)
+    random_vector_q = [bool(random.randint(0, 1)) for _ in range(config.chemical_embedding_size)]
+    print(random_vector_q)
+    query_vector = convert_bool_list_to_bytes(random_vector_q)
+    for tag in ("test", "test1", "test2"):
+        res = embedding_db.search(
+            collection_name="chemical_embeddings",
+            data=[query_vector],
+            limit=5,
+            anns_field="chemical_embedding",
+            filter=f'ARRAY_CONTAINS(tags, "{tag}")',
+            search_params={"metric_type": "JACCARD"},
+        )
+        print(res)
+
+
+# only run when executed as a script, so importing this module (e.g. during
+# pytest collection) does not insert rows as a side effect
+if __name__ == "__main__":
+    test_chemical_embeddings_db()