Skip to content

Commit

Permalink
Merge pull request #3 from pepkit/dev
Browse files Browse the repository at this point in the history
move to fastembed
  • Loading branch information
nleroy917 authored Dec 11, 2023
2 parents 336dc32 + 8645e6f commit 987faea
Show file tree
Hide file tree
Showing 16 changed files with 252 additions and 71 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,5 @@ cython_debug/

# qdrant
qdrant_storage/

scripts/
local_cache/
scripts/
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
7 changes: 6 additions & 1 deletion keywords.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ cell
protocol
description
processing
source
source
table
file_path
pep_version
project_name
experiment_metadata
28 changes: 19 additions & 9 deletions pepembed/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# %%
import sys
import logging
import os
Expand All @@ -16,7 +17,7 @@
PKG_NAME,
DESCRIPTION_COLUNM,
PROJECT_TABLE,
PROJECT_COLUMN,
CONFIG_COLUMN,
PROJECT_NAME_COLUMN,
NAMESPACE_COLUMN,
TAG_COLUMN,
Expand All @@ -29,7 +30,7 @@
from .pepembed import PEPEncoder
from .utils import batch_generator


# %%
def main():
"""Entry point for the CLI."""
load_dotenv()
Expand Down Expand Up @@ -75,7 +76,7 @@ def main():
# get list of peps
_LOGGER.info("Pulling PEPs from database.")
curs.execute(
f"SELECT {NAMESPACE_COLUMN}, {PROJECT_NAME_COLUMN}, {TAG_COLUMN}, {PROJECT_COLUMN}, {DESCRIPTION_COLUNM}, {ROW_ID_COLUMN} FROM {PROJECT_TABLE}"
f"SELECT {NAMESPACE_COLUMN}, {PROJECT_NAME_COLUMN}, {TAG_COLUMN}, {CONFIG_COLUMN}, {ROW_ID_COLUMN} FROM {PROJECT_TABLE}"
)
projects = curs.fetchall()

Expand All @@ -94,9 +95,9 @@ def main():

# we need to work in batches since it's much faster
projects_encoded = []
for batch in tqdm(
for i, batch in enumerate(tqdm(
batch_generator(projects, BATCH_SIZE), total=len(projects) // BATCH_SIZE
):
)):
# build list of descriptions for batch
descs = []
for p in batch:
Expand All @@ -105,6 +106,10 @@ def main():
descs.append(d)
else:
descs.append(f"{p[0]} {p[1]} {p[2]}")

# every 100th batch, print out the first description
if i % 100 == 0:
_LOGGER.info(f"First description: {descs[0]}")

# encode descriptions
try:
Expand Down Expand Up @@ -133,13 +138,17 @@ def main():

# connect to qdrant
qdrant = QdrantClient(
url=QDRANT_HOST,
url=QDRANT_HOST,
port=QDRANT_PORT,
api_key=QDRANT_API_KEY,
)

# get the collection info
COLLECTION = args.qdrant_collection or os.environ.get("QDRANT_COLLECTION") or QDRANT_DEFAULT_COLLECTION
COLLECTION = (
args.qdrant_collection
or os.environ.get("QDRANT_COLLECTION")
or QDRANT_DEFAULT_COLLECTION
)

# recreate the collection if necessary
if args.recreate_collection:
Expand All @@ -148,7 +157,7 @@ def main():
vectors_config=models.VectorParams(
size=EMBEDDING_DIM, distance=models.Distance.COSINE
),
on_disk_payload=True
on_disk_payload=True,
)
collection_info = qdrant.get_collection(collection_name=COLLECTION)
else:
Expand All @@ -164,7 +173,7 @@ def main():
vectors_config=models.VectorParams(
size=EMBEDDING_DIM, distance=models.Distance.COSINE
),
on_disk_payload=True
on_disk_payload=True,
)
collection_info = qdrant.get_collection(collection_name=COLLECTION)

Expand Down Expand Up @@ -214,6 +223,7 @@ def main():
"""
)


if __name__ == "__main__":
try:
sys.exit(main())
Expand Down
5 changes: 2 additions & 3 deletions pepembed/const.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from sentence_transformers import __version__ as st_version
from platform import python_version
from logging import CRITICAL, DEBUG, ERROR, INFO, WARN

Expand All @@ -12,16 +11,16 @@
QDRANT_DEFAULT_COLLECTION = "pephub"

VERSIONS = {
"sentence_transformers_version": st_version,
"python_version": python_version(),
}

DEFAULT_KEYWORDS = ["cell", "protocol", "description", "processing", "source"]
DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L12-v2"

MIN_DESCRIPTION_LENGTH = 5

PROJECT_TABLE = "projects"
PROJECT_COLUMN = "project_value"
CONFIG_COLUMN = "config"
PROJECT_NAME_COLUMN = "name"
CONFIG_COLUMN = "config"
NAMESPACE_COLUMN = "namespace"
Expand Down
62 changes: 21 additions & 41 deletions pepembed/pepembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from typing import List, Dict, Any, Union
from peppy import Project
from peppy.const import SAMPLE_MODS_KEY, CONSTANT_KEY, CONFIG_KEY, NAME_KEY
from sentence_transformers import SentenceTransformer
from fastembed.embedding import FlagEmbedding as Embedding

import flatdict

from .utils import read_in_key_words
from .const import DEFAULT_KEYWORDS, MIN_DESCRIPTION_LENGTH


class PEPEncoder(SentenceTransformer):
class PEPEncoder(Embedding):
"""
Simple wrapper of the sentence transformer class that lets you
embed metadata inside a PEP.
Expand All @@ -34,7 +35,11 @@ def mine_metadata_from_dict(
:param project: A dictionary representing a peppy.Project instance.
:param min_desc_length: The minimum length of the description.
"""
# project_config = project.get(CONFIG_KEY) or project.get(CONFIG_KEY.replace("_", ""))
# project_config = project.get(CONFIG_KEY) or project.get(
# CONFIG_KEY.replace("_", "")
# )
# fix bug where config key is not in the project,
# new database schema does not have config key
project_config = project
if project_config is None:
return ""
Expand All @@ -44,14 +49,24 @@ def mine_metadata_from_dict(
):
return project[NAME_KEY] or ""

project_level_dict: dict = project_config[SAMPLE_MODS_KEY][CONSTANT_KEY]
# project_level_dict: dict = project_config[SAMPLE_MODS_KEY][CONSTANT_KEY]
# Flatten dictionary
project_level_dict: dict = flatdict.FlatDict(project_config)
project_level_attrs = list(project_level_dict.keys())
desc = ""

# build up a description
# search for "summary" in keys, if found, use that first, then pop it out
# should catch if key simply contains "summary"
for attr in project_level_attrs:
if "summary" in attr:
desc += str(project_level_dict[attr]) + " "
project_level_attrs.remove(attr)
break

# build up a description using the rest
for attr in project_level_attrs:
if any([kw in attr for kw in self.keywords]):
desc += project_level_dict[attr] + " "
desc += str(project_level_dict[attr]) + " "

# return if description is sufficient
if len(desc) > min_desc_length:
Expand All @@ -74,38 +89,3 @@ def mine_metadata_from_pep(
return self.mine_metadata_from_dict(
project_dict, min_desc_length=min_desc_length
)

def embed(
    self, projects: Union[dict, List[dict], Project, List[Project]], **kwargs
) -> np.ndarray:
    """
    Embed a PEP based on its metadata.

    :param projects: A PEP or list of PEPs to embed. Accepts a single
        dict, a single peppy.Project, or a list of either.
    :param kwargs: Keyword arguments to pass to the `encode` method of the
        SentenceTransformer class.
    :return: The embedding(s) returned by the underlying `encode` call.
    :raises ValueError: If `projects` is not one of the accepted types.
    """
    # if single dictionary is passed
    if isinstance(projects, dict):
        desc = self.mine_metadata_from_dict(projects)
        return super().encode(desc, **kwargs)

    # if single peppy.Project is passed
    elif isinstance(projects, Project):
        desc = self.mine_metadata_from_pep(projects)
        return super().encode(desc, **kwargs)

    # if list of dictionaries is passed
    # NOTE(review): list dispatch inspects only projects[0]; assumes the
    # list is homogeneous — confirm callers never pass mixed lists
    elif isinstance(projects, list) and isinstance(projects[0], dict):
        descs = [self.mine_metadata_from_dict(p) for p in projects]
        return super().encode(descs, **kwargs)

    # if list of peppy.Projects is passed
    elif isinstance(projects, list) and isinstance(projects[0], Project):
        descs = [self.mine_metadata_from_pep(p) for p in projects]
        return super().encode(descs, **kwargs)

    # else, return ValueError
    else:
        raise ValueError(
            "Invalid input type. Must be a dictionary, peppy.Project, list of dictionaries, or list of peppy.Projects."
        )
7 changes: 7 additions & 0 deletions pepembed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ def read_in_key_words(key_words_file: str) -> List[str]:
return key_words


def generate_key_words(key_words_file: str) -> List[str]:
    """Generate keywords from current PEPs by finding their most common shared attributes.

    :param key_words_file: path to a seed keywords file (not yet used by this stub)
    :return: the generated keywords; currently always an empty list
    """
    # TODO: generate a dynamic list of keywords for custom PEPs
    generated: List[str] = []
    return generated


def batch_generator(iterable, batch_size) -> List:
"""Batch generator."""
l = len(iterable)
Expand Down
2 changes: 1 addition & 1 deletion production.env
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ export QDRANT_HOST=`pass databio/pephub/qdrant_host`
export QDRANT_PORT=6333
export QDRANT_API_KEY=`pass databio/pephub/qdrant_api_key`

export HF_MODEL="sentence-transformers/all-MiniLM-L12-v2"
export HF_MODEL="BAAI/bge-small-en-v1.5"
3 changes: 2 additions & 1 deletion requirements/requirements-all.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
logmuse
sentence-transformers
fastembed
peppy
python-dotenv
qdrant-client
psycopg2
ubiquerg
tqdm
flatdict
37 changes: 24 additions & 13 deletions run_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
verbosity=None,
logging_level=None,
recreate_collection=True,
hf_model="sentence-transformers/all-MiniLM-L12-v2",
hf_model=os.environ.get("HF_MODEL"),
keywords_file="keywords.txt",
batch_size=DEFAULT_BATCH_SIZE,
upsert_batch_size=DEFAULT_UPSERT_BATCH_SIZE,
Expand Down Expand Up @@ -107,19 +107,19 @@
# initialize encoder
_LOGGER.info("Initializing encoder.")
encoder = PEPEncoder(args.hf_model, keywords_file=args.keywords_file)
EMBEDDING_DIM = int(encoder.get_sentence_embedding_dimension())
EMBEDDING_DIM = 384 # hardcoded for sentence-transformers/all-MiniLM-L12-v2 and BAAI/bge-small-en-v1.5
_LOGGER.info(f"Computing embeddings of {EMBEDDING_DIM} dimensions.")

# %%
# encode PEPs in batches
_LOGGER.info("Encoding PEPs.")
BATCH_SIZE = args.batch_size or DEFAULT_BATCH_SIZE

# we need to work in batches since it's much faster
projects_encoded = []
for i, batch in enumerate(tqdm(
batch_generator(projects, BATCH_SIZE), total=len(projects) // BATCH_SIZE
)):
i = 0
for batch in tqdm(
batch_generator(projects, BATCH_SIZE), total=(len(projects) // BATCH_SIZE)
):
# build list of descriptions for batch
descs = []
for p in batch:
Expand All @@ -128,14 +128,12 @@
descs.append(d)
else:
descs.append(f"{p[0]} {p[1]} {p[2]}")

# every 100th batch, print out the first description
if i % 100 == 0:
_LOGGER.info(f"First description: {descs[0]}")

# encode descriptions
try:
embeddings = encoder.encode(descs)
embeddings = encoder.embed(descs)
projects_encoded.extend(
[
dict(
Expand All @@ -149,6 +147,7 @@
)
except Exception as e:
_LOGGER.error(f"Error encoding batch: {e}")
i += 1

# %%
_LOGGER.info("Encoding complete.")
Expand Down Expand Up @@ -181,6 +180,13 @@
size=EMBEDDING_DIM, distance=models.Distance.COSINE
),
on_disk_payload=True,
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
type=models.ScalarType.INT8,
quantile=0.99,
always_ram=True,
),
),
)
collection_info = qdrant.get_collection(collection_name=COLLECTION)
else:
Expand All @@ -197,6 +203,13 @@
size=EMBEDDING_DIM, distance=models.Distance.COSINE
),
on_disk_payload=True,
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
type=models.ScalarType.INT8,
quantile=0.99,
always_ram=True,
),
),
)
collection_info = qdrant.get_collection(collection_name=COLLECTION)

Expand Down Expand Up @@ -226,9 +239,7 @@
batch_generator(all_points, UPSERT_BATCH_SIZE),
total=len(all_points) // UPSERT_BATCH_SIZE,
):
operation_info = qdrant.upsert(
collection_name=COLLECTION, wait=True, points=batch
)
operation_info = qdrant.upsert(collection_name=COLLECTION, wait=True, points=batch)

assert operation_info.status == "completed"

Expand All @@ -244,4 +255,4 @@
"ids": [0, 3, 100]
}}' 'http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{COLLECTION}/points'
"""
)
)
Loading

0 comments on commit 987faea

Please sign in to comment.