Add a vector db
This commit is contained in:
parent
167fc1607d
commit
940f603030
3
aslite/config.py
Normal file
3
aslite/config.py
Normal file
@ -0,0 +1,3 @@
|
||||
chemical_embedding_size: int = 2048
|
||||
consistency_level = "Strong" # in production should probably be Eventually, but Strong might make testing easier
|
||||
chemical_index_type = "BIN_FLAT"
|
||||
30
aslite/db.py
30
aslite/db.py
@ -8,6 +8,8 @@ import os
|
||||
import sqlite3, zlib, pickle, tempfile
|
||||
from sqlitedict import SqliteDict
|
||||
from contextlib import contextmanager
|
||||
from pymilvus import MilvusClient, DataType
|
||||
from aslite import config
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# global configuration
|
||||
@ -103,7 +105,7 @@ flag='r': open for read-only
|
||||
PAPERS_DB_FILE = os.path.join(DATA_DIR, 'papers.db')
|
||||
# stores account-relevant info, like which tags exist for which papers
|
||||
DICT_DB_FILE = os.path.join(DATA_DIR, 'dict.db')
|
||||
|
||||
EMBEDDING_DB_FILE = os.path.join(DATA_DIR, 'embeddings.db') #NOTE: once we set it up with docker it will probably need to be a standalone db
|
||||
def get_papers_db(flag='r', autocommit=True):
|
||||
assert flag in ['r', 'c']
|
||||
pdb = CompressedSqliteDict(PAPERS_DB_FILE, tablename='papers', flag=flag, autocommit=autocommit)
|
||||
@ -128,7 +130,33 @@ def get_email_db(flag='r', autocommit=True):
|
||||
assert flag in ['r', 'c']
|
||||
edb = SqliteDict(DICT_DB_FILE, tablename='email', flag=flag, autocommit=autocommit)
|
||||
return edb
|
||||
def setup_chemical_embeddings_collection(client : MilvusClient):
|
||||
|
||||
schema = MilvusClient.create_schema(auto_id=False)
|
||||
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
|
||||
schema.add_field(field_name="chemical_embedding", datatype=DataType.BINARY_VECTOR, dim=config.chemical_embedding_size)
|
||||
schema.add_field(field_name="paper_id", datatype=DataType.INT64)
|
||||
schema.add_field(field_name="category", datatype=DataType.VARCHAR, max_length=127)
|
||||
schema.add_field(field_name="SMILES", datatype=DataType.VARCHAR, max_length=65535)
|
||||
schema.add_field(field_name="tags",datatype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=100, max_length=127)
|
||||
index_params = client.prepare_index_params()
|
||||
index_params.add_index(field_name="chemical_embedding", index_type=config.chemical_index_type, metric_type="JACCARD")
|
||||
client.create_collection(
|
||||
collection_name="chemical_embeddings",
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=config.consistency_level
|
||||
)
|
||||
def get_embeddings_db():
|
||||
client = MilvusClient(EMBEDDING_DB_FILE)
|
||||
if not client.has_collection("chemical_embeddings"):
|
||||
setup_chemical_embeddings_collection(client)
|
||||
return client
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
our "feature store" is currently just a pickle file, may want to consider hdf5 in the future
|
||||
|
||||
@ -5,10 +5,11 @@ itsdangerous==2.2.0
|
||||
Jinja2==3.1.5
|
||||
joblib==1.4.2
|
||||
MarkupSafe==3.0.2
|
||||
numpy==1.21.4
|
||||
scikit-learn==1.0.1
|
||||
scipy==1.10.1
|
||||
sgmllib3k==1.0.0
|
||||
numpy
|
||||
scikit-learn
|
||||
Scipy
|
||||
Sgmllib3k==1.0.0
|
||||
sqlitedict==1.7.0
|
||||
threadpoolctl==3.5.0
|
||||
Werkzeug==2.3.8
|
||||
pymilvus==2.5.4
|
||||
|
||||
61
test_db.py
Normal file
61
test_db.py
Normal file
@ -0,0 +1,61 @@
|
||||
from aslite.db import *
|
||||
def convert_bool_list_to_bytes(bool_list):
|
||||
if len(bool_list) % 8 != 0:
|
||||
raise ValueError("The length of a boolean list must be a multiple of 8")
|
||||
|
||||
byte_array = bytearray(len(bool_list) // 8)
|
||||
for i, bit in enumerate(bool_list):
|
||||
if bit == 1:
|
||||
index = i // 8
|
||||
shift = i % 8
|
||||
byte_array[index] |= (1 << shift)
|
||||
return bytes(byte_array)
|
||||
|
||||
|
||||
def test_chemical_embeddings_db(): # temporary test code, should not be run on a full db
|
||||
import random
|
||||
embedding_db: MilvusClient = get_embeddings_db()
|
||||
print(embedding_db.describe_collection("chemical_embeddings"))
|
||||
random_embeddings = [[bool(random.randint(0, 1)) for _ in range(config.chemical_embedding_size)] for _ in range(1000)]
|
||||
data = [
|
||||
{
|
||||
"chemical_embedding": convert_bool_list_to_bytes(random_embeddings[i]),
|
||||
"tags": ["test", "test 1" if random.randint(0, 1)==0 else "test2"],
|
||||
"category": "chemistry?",
|
||||
"paper_id": random.randint(1, 1000),
|
||||
"SMILES": "CN=C=O"
|
||||
} for i in range(len(random_embeddings))
|
||||
]
|
||||
res = embedding_db.insert(collection_name="chemical_embeddings", data=data)
|
||||
print(res)
|
||||
random_vector_q = [bool(random.randint(0, 1)) for _ in range(config.chemical_embedding_size)]
|
||||
print(random_vector_q)
|
||||
res = embedding_db.search(
|
||||
collection_name="chemical_embeddings",
|
||||
data=[convert_bool_list_to_bytes(random_vector_q)],
|
||||
limit=5,
|
||||
anns_field="chemical_embedding",
|
||||
filter='ARRAY_CONTAINS(tags, "test")',
|
||||
search_params={"metric_type": "JACCARD"},
|
||||
)
|
||||
print(res)
|
||||
res = embedding_db.search(
|
||||
collection_name="chemical_embeddings",
|
||||
data=[convert_bool_list_to_bytes(random_vector_q)],
|
||||
limit=5,
|
||||
anns_field="chemical_embedding",
|
||||
filter='ARRAY_CONTAINS(tags, "test1")',
|
||||
search_params={"metric_type": "JACCARD"},
|
||||
)
|
||||
print(res)
|
||||
res = embedding_db.search(
|
||||
collection_name="chemical_embeddings",
|
||||
data=[convert_bool_list_to_bytes(random_vector_q)],
|
||||
limit=5,
|
||||
anns_field="chemical_embedding",
|
||||
filter='ARRAY_CONTAINS(tags, "test2")',
|
||||
search_params={"metric_type": "JACCARD"},
|
||||
)
|
||||
print(res)
|
||||
|
||||
test_chemical_embeddings_db()
|
||||
Loading…
Reference in New Issue
Block a user