
add flake8 lint in ci
zzzming committed Mar 13, 2024
1 parent 2ebd487 commit 6b8188e
Showing 6 changed files with 16 additions and 25 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/modules.yml
@@ -46,10 +46,10 @@ jobs:
           cd ragstack/colbertbase
           poetry install
-      # - name: Lint with flake8
-      #   run: |
-      #     cd ragstack
-      #     poetry run flake8 .
+      - name: Lint with flake8
+        run: |
+          cd ragstack/colbertbase
+          poetry run flake8 ./colbertbase
       - name: Test with pytest
         run: |
3 changes: 3 additions & 0 deletions ragstack/colbertbase/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 140
+ignore = E221,E222,E302,E252,E225,E303,E293,W293,E121,E125,E226
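
The new config raises the line limit to 140 and silences mostly whitespace and blank-line codes (E2xx/E3xx/W293) plus two continuation-indent codes (E121, E125). For reference, the same checks can be reproduced from Python with flake8's legacy API; a minimal sketch, assuming flake8 is installed and the script is run from ragstack/colbertbase:

    # Sketch: run flake8 programmatically with the same settings as .flake8.
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide(
        max_line_length=140,
        ignore=["E221", "E222", "E302", "E252", "E225",
                "E303", "E293", "W293", "E121", "E125", "E226"],
    )
    report = style_guide.check_files(["colbertbase"])
    print(f"{report.total_errors} lint error(s) found")

In practice the CI step above just shells out (poetry run flake8 ./colbertbase) and picks this file up automatically, since flake8 reads .flake8 from the working directory.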
12 changes: 2 additions & 10 deletions ragstack/colbertbase/colbertbase/astra_colbert_embedding.py
@@ -3,6 +3,7 @@
 import itertools
 import torch  # it should be part of colbert dependencies
 import uuid
+import os
 from .token_embedding import TokenEmbeddings, PerTokenEmbeddings, PassageEmbeddings


@@ -54,14 +55,6 @@ class ColbertTokenEmbeddings(TokenEmbeddings):
     @classmethod
     def validate_environment(self, values: Dict) -> Dict:
         """Validate colbert and its dependency is installed."""
-        try:
-            from colbert import CollectionEncoder
-        except ImportError as exc:
-            raise ImportError(
-                "Could not import colbert library. "
-                "Please install it with `pip install colbert`"
-            ) from exc
-
         try:
             import torch
             if torch.cuda.is_available():
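
The deleted block above is the standard optional-dependency guard, re-raising ImportError with an install hint; the surviving torch check follows the same shape. As a standalone sketch of that pattern (the function name is illustrative):

    def require_colbert() -> None:
        # Optional-dependency guard: fail early with an actionable
        # message when the colbert package is missing.
        try:
            import colbert  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Could not import colbert library. "
                "Please install it with `pip install colbert`"
            ) from exc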
@@ -160,7 +153,7 @@ def encode_query(
     def encode(self, texts: List[str], title: str="") -> List[PassageEmbeddings]:
         # collection = Collection(texts)
         # batches = collection.enumerate_batches(rank=Run().rank)
-        '''
+        '''
         config = ColBERTConfig(
             doc_maxlen=self.__doc_maxlen,
             nbits=self.__nbits,
@@ -176,7 +169,6 @@ def encode(self, texts: List[str], title: str="") -> List[PassageEmbeddings]:
         # split up embeddings by counts, a list of the number of tokens in each passage
         start_indices = [0] + list(itertools.accumulate(count[:-1]))
         embeddings_by_part = [embeddings[start:start+count] for start, count in zip(start_indices, count)]
-        size = len(embeddings_by_part)
         for part, embedding in enumerate(embeddings_by_part):
             collectionEmbd = PassageEmbeddings(text=texts[part], title=title, part=part)
             pid = collectionEmbd.id()
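The surviving split logic is a prefix-sum partition: count holds the number of token embeddings per passage, itertools.accumulate turns it into start offsets, and the flat embeddings list is sliced back into per-passage chunks (the deleted size variable was apparently unused). A self-contained sketch with illustrative data:

    import itertools

    # Stand-ins for the real tensors: 6 token embeddings across 3 passages.
    embeddings = ["e0", "e1", "e2", "e3", "e4", "e5"]
    count = [2, 3, 1]  # token embeddings per passage

    start_indices = [0] + list(itertools.accumulate(count[:-1]))  # [0, 2, 5]
    embeddings_by_part = [embeddings[s:s + c] for s, c in zip(start_indices, count)]
    assert embeddings_by_part == [["e0", "e1"], ["e2", "e3", "e4"], ["e5"]]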
8 changes: 3 additions & 5 deletions ragstack/colbertbase/colbertbase/astra_db.py
@@ -19,8 +19,7 @@ def __init__(
         verbose: bool=False,
         timeout: int=60,
         **kwargs,
-    ):
-
+    ):
         required_cred(secure_connect_bundle)
         required_cred(astra_token)

@@ -88,7 +87,6 @@ def __init__(
         self.query_part_by_pk_stmt = self.session.prepare(query_part_by_pk)
 
         print("statements are prepared")
-
 
     def create_tables(self):
         self.session.execute(f"""
@@ -163,7 +161,7 @@ def insert_colbert_embeddings_chunks(
         # insert colbert embeddings
         for passageEmbd in embeddings:
             title = passageEmbd.title()
-            parameters = [(title, e[1].part, e[1].id, e[1].get_embeddings()) for e in enumerate(passageEmbd.get_all_token_embeddings())]
+            parameters = [(title, e[1].part, e[1].id, e[1].get_embeddings()) for e in enumerate(passageEmbd.get_all_token_embeddings())]
             execute_concurrent_with_args(self.session, self.insert_colbert_stmt, parameters)
 
     def delete_title(self, title: str):
@@ -176,4 +174,4 @@ def delete_title(self, title: str):
 
     def close(self):
         self.session.shutdown()
-        self.cluster.shutdown()
+        self.cluster.shutdown()
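
insert_colbert_embeddings_chunks builds one parameter tuple per token embedding and fans the inserts out with the Cassandra driver's execute_concurrent_with_args. A rough sketch of that pattern, assuming an open session and an embeddings object like the ones above (the table and column names here are illustrative, not from the repo):

    from cassandra.concurrent import execute_concurrent_with_args

    # `session` is an open cassandra-driver session and `passage` a
    # PassageEmbeddings instance; both are assumed to exist.
    insert_stmt = session.prepare(
        "INSERT INTO colbert_embeddings (title, part, embedding_id, bert_embedding) "
        "VALUES (?, ?, ?, ?)"
    )
    parameters = [
        (title, token.part, token.id, token.get_embeddings())
        for token in passage.get_all_token_embeddings()
    ]
    # Runs the prepared INSERT once per tuple, with driver-managed concurrency.
    execute_concurrent_with_args(session, insert_stmt, parameters, concurrency=50)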
3 changes: 1 addition & 2 deletions ragstack/colbertbase/colbertbase/astra_retriever.py
@@ -2,7 +2,6 @@
 
 from .astra_db import AstraDB
 from torch import tensor
-from typing import List
 import torch
 import math
 
@@ -124,7 +123,7 @@ def retrieve(
             rows = self.astra.session.execute(self.astra.query_colbert_parts_stmt, [title, part])
             embeddings_for_part = [tensor(row.bert_embedding) for row in rows]
             # score based on max similarity: the function returns the highest similarity score
-            #(i.e., the maximum dot product value) between the query vector and any of the embedding vectors in the list.
+            # (i.e., the maximum dot product value) between the query vector and any of the embedding vectors in the list.
             scores[(title, part)] = sum(max_similarity_torch(qv, embeddings_for_part, self.is_cuda) for qv in query_encodings)
             # load the source chunk for the top k documents
             docs_by_score = sorted(scores, key=scores.get, reverse=True)[:k]
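The scoring line is ColBERT-style MaxSim: each query token vector is compared against every token embedding of the candidate part, the best dot product is kept, and the per-token maxima are summed. A minimal sketch of such a helper (the repo's max_similarity_torch may differ, e.g. in CUDA handling):

    import torch
    from torch import Tensor
    from typing import List

    def max_similarity(qv: Tensor, embeddings: List[Tensor]) -> float:
        # Best dot product between one query token vector and any
        # token embedding of the passage part (MaxSim).
        stacked = torch.stack(embeddings)    # (num_tokens, dim)
        return torch.max(stacked @ qv).item()

    # part score = sum of per-query-token best matches:
    # scores[(title, part)] = sum(max_similarity(qv, embeddings_for_part) for qv in query_encodings)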
7 changes: 3 additions & 4 deletions ragstack/colbertbase/colbertbase/token_embedding.py
@@ -2,7 +2,7 @@
 # this is a base class for ColBERT per token based embedding
 
 from abc import ABC, abstractmethod
-from typing import Any, List
+from typing import List
 import uuid
 
 class PerTokenEmbeddings():
@@ -22,7 +22,7 @@ def __init__(
         self.title = title
         self.part =part
 
-    def add_embeddings(self, embeddings: List[float]):
+    def add_embeddings(self, embeddings: List[float]):
         self.__embeddings = embeddings
 
     def get_embeddings(self) -> List[float]:
@@ -52,7 +52,6 @@ def __init__(
         model: str = "colbert-ir/colbertv2.0",
         dim: int = 128,
     ):
-        #self.token_ids = token_ids
         self.__text = text
         self.__token_embeddings = []
         if id is None:
@@ -62,7 +61,7 @@ def __init__(
         self.__model = model
         self.__dim = dim
         self.__title = title
-        self.__part = part
+        self.__part = part
 
     def model(self):
         return self.__model
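For orientation, these classes act as plain containers: a PassageEmbeddings carries the passage text, title, part, model, dim, and a generated uuid id, plus its PerTokenEmbeddings. A hypothetical usage sketch based only on the attributes and methods visible in this diff (constructor signatures beyond those shown are assumptions):

    # Hypothetical usage; exact constructor signatures are assumptions.
    passage = PassageEmbeddings(text="ColBERT encodes every token.", title="intro", part=0)

    token = PerTokenEmbeddings(title="intro", part=0)
    token.add_embeddings([0.12, -0.03, 0.88])  # 128-dim vectors in the real model
    # tokens are later read back via passage.get_all_token_embeddings(),
    # as in astra_db.insert_colbert_embeddings_chunks above.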
