Skip to content

Commit

Permalink
Merge branch 'master' into prf
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Nov 24, 2024
2 parents 084e277 + 6909587 commit 9d17bb2
Show file tree
Hide file tree
Showing 30 changed files with 609 additions and 655 deletions.
File renamed without changes.
51 changes: 0 additions & 51 deletions .github/workflows/push.yml

This file was deleted.

36 changes: 36 additions & 0 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: style

on:
push: {branches: [master]} # pushes to master
pull_request: {} # all PRs

jobs:
ruff:
strategy:
matrix:
python-version: ['3.10']
os: ['ubuntu-latest']

runs-on: ${{ matrix.os }}
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Cache Dependencies
uses: actions/cache@v4
with:
path: ${{ env.pythonLocation }}
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt') }}

- name: Install Dependencies
run: |
pip install --upgrade -r requirements-dev.txt
pip install -e .
- name: Ruff
run: 'ruff check --output-format=github pyterrier_dr'
62 changes: 62 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
name: test

on:
push: {branches: [master]} # pushes to master
pull_request: {} # all PRs
schedule: [cron: '0 12 * * 3'] # every Wednesday at noon

jobs:
pytest:
strategy:
matrix:
os: ['ubuntu-latest']
python-version: ['3.8', '3.12']

runs-on: ${{ matrix.os }}
env:
runtag: ${{ matrix.os }}-${{ matrix.python-version }}

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Cache Dependencies
uses: actions/cache@v4
with:
path: ${{ env.pythonLocation }}
key: ${{ env.runtag }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt') }}

- name: Loading Torch models from cache
uses: actions/cache@v3
with:
path: /home/runner/.cache/
key: model-cache

- name: Install Dependencies
run: |
pip install --upgrade -r requirements.txt -r requirements-dev.txt
pip install -e .
- name: Unit Test
run: |
pytest --durations=20 -p no:faulthandler --json-report --json-report-file ${{ env.runtag }}.results.json --cov pyterrier_dr --cov-report json:${{ env.runtag }}.coverage.json tests/
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v4
with:
path: ${{ env.runtag }}.*.json
overwrite: true

- name: Report Test Results
if: always()
run: |
printf "**Test Results**\n\n" >> $GITHUB_STEP_SUMMARY
jq '.summary' ${{ env.runtag }}.results.json >> $GITHUB_STEP_SUMMARY
printf "\n\n**Test Coverage**\n\n" >> $GITHUB_STEP_SUMMARY
jq '.files | to_entries[] | " - `" + .key + "`: **" + .value.summary.percent_covered_display + "%**"' -r ${{ env.runtag }}.coverage.json >> $GITHUB_STEP_SUMMARY
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,5 @@ dmypy.json

# Pyre type checker
.pyre/

.DS_Store
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024, Sean MacAvaney

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1 @@
include requirements.txt
recursive-include pyterrier_dr *.rst
43 changes: 43 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "pyterrier-dr"
description = "Dense Retrieval for PyTerrier"
requires-python = ">=3.8"
authors = [
{name = "Sean MacAvaney", email = "[email protected]"},
]
maintainers = [
{name = "Sean MacAvaney", email = "[email protected]"},
]
readme = "README.rst"
classifiers = [
"Programming Language :: Python",
"Operating System :: OS Independent",
"Topic :: Text Processing",
"Topic :: Text Processing :: Indexing",
"License :: OSI Approved :: MIT License",
]
dynamic = ["version", "dependencies"]

[tool.setuptools.dynamic]
version = {attr = "pyterrier_dr.__version__"}
dependencies = {file = ["requirements.txt"]}

[project.optional-dependencies]
bgem3 = [
"FlagEmbedding",
]

[tool.setuptools.packages.find]
exclude = ["tests"]

[project.urls]
Repository = "https://github.com/terrierteam/pyterrier_dr"
"Bug Tracker" = "https://github.com/terrierteam/pyterrier_dr/issues"

[project.entry-points."pyterrier.artifact"]
"dense_index.flex" = "pyterrier_dr:FlexIndex"
"cde_cache.np_pickle" = "pyterrier_dr:CDECache"
27 changes: 16 additions & 11 deletions pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
__version__ = '0.2.0'

from .util import SimFn, infer_device
from .indexes import DocnoFile, NilIndex, NumpyIndex, RankedLists, FaissFlat, FaissHnsw, MemIndex, TorchIndex
from .flex import FlexIndex
from .biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
from .hgf_models import HgfBiEncoder, TasB, RetroMAE
from .sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from .tctcolbert_model import TctColBert
from .electra import ElectraScorer
from .bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
from .cde import CDE, CDECache
from .prf import average_prf, vector_prf
from pyterrier_dr.util import SimFn, infer_device
from pyterrier_dr.indexes import DocnoFile, NilIndex, NumpyIndex, RankedLists, FaissFlat, FaissHnsw, MemIndex, TorchIndex
from pyterrier_dr.flex import FlexIndex
from pyterrier_dr.biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
from pyterrier_dr.hgf_models import HgfBiEncoder, TasB, RetroMAE
from pyterrier_dr.sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from pyterrier_dr.tctcolbert_model import TctColBert
from pyterrier_dr.electra import ElectraScorer
from pyterrier_dr.bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
from pyterrier_dr.cde import CDE, CDECache
from pyterrier_dr.prf import average_prf, vector_prf

__all__ = ["FlexIndex", "DocnoFile", "NilIndex", "NumpyIndex", "RankedLists", "FaissFlat", "FaissHnsw", "MemIndex", "TorchIndex",
"BiEncoder", "BiQueryEncoder", "BiDocEncoder", "BiScorer", "HgfBiEncoder", "TasB", "RetroMAE", "SBertBiEncoder", "Ance",
"Query2Query", "GTR", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
"SimFn", "infer_device", "average_prf", "vector_prf"]
12 changes: 6 additions & 6 deletions pyterrier_dr/bge_m3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from tqdm import tqdm
import pyterrier as pt
import pandas as pd
import numpy as np
import torch
import pyterrier_alpha as pta
from .biencoder import BiEncoder

class BGEM3(BiEncoder):
Expand All @@ -16,7 +16,7 @@ def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, tex
self.device = torch.device(device)
try:
from FlagEmbedding import BGEM3FlagModel
except ImportError as e:
except ImportError:
raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'")

self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)
Expand Down Expand Up @@ -61,8 +61,8 @@ def encode(self, texts):
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])
pta.validate.columns(inp, includes=['query'])

# check if inp is empty
if len(inp) == 0:
if self.dense:
Expand Down Expand Up @@ -102,14 +102,14 @@ def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length
self.dense = return_dense
self.sparse = return_sparse
self.multivecs = return_colbert_vecs

def encode(self, texts):
return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# check if the input dataframe contains the field(s) specified in the text_field
assert all(c in inp.columns for c in [self.bge_factory.text_field])
pta.validate.columns(inp, includes=[self.bge_factory.text_field])
# check if inp is empty
if len(inp) == 0:
if self.dense:
Expand Down
45 changes: 22 additions & 23 deletions pyterrier_dr/biencoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from more_itertools import chunked
import numpy as np
import torch
from torch import nn
import pyterrier as pt
import pandas as pd
import pyterrier_alpha as pta
from . import SimFn


Expand All @@ -21,22 +19,20 @@ def encode_docs(self, texts, batch_size=None) -> np.array:
raise NotImplementedError()

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
columns = set(inp.columns)
modes = [
(['qid', 'query', self.text_field], self.scorer),
(['qid', 'query_vec', self.text_field], self.scorer),
(['qid', 'query', 'doc_vec'], self.scorer),
(['qid', 'query_vec', 'doc_vec'], self.scorer),
(['query'], self.query_encoder),
([self.text_field], self.doc_encoder),
]
for fields, fn in modes:
if all(f in columns for f in fields):
return fn()(inp)
message = f'Unexpected input with columns: {inp.columns}. Supports:'
for fields, fn in modes:
message += f'\n - {fn.__doc__.strip()}: {fields}'
raise RuntimeError(message)
with pta.validate.any(inp) as v:
v.columns(includes=['query', self.text_field], mode='scorer')
v.columns(includes=['query_vec', self.text_field], mode='scorer')
v.columns(includes=['query', 'doc_vec'], mode='scorer')
v.columns(includes=['query_vec', 'doc_vec'], mode='scorer')
v.columns(includes=['query'], mode='query_encoder')
v.columns(includes=[self.text_field], mode='doc_encoder')

if v.mode == 'scorer':
return self.scorer()(inp)
elif v.mode == 'query_encoder':
return self.query_encoder()(inp)
elif v.mode == 'doc_encoder':
return self.doc_encoder()(inp)

def query_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
Expand Down Expand Up @@ -76,7 +72,7 @@ def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])
pta.validate.columns(inp, includes=['query'])
it = inp['query'].values
it, inv = np.unique(it, return_inverse=True)
if self.verbose:
Expand All @@ -99,7 +95,7 @@ def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in [self.text_field])
pta.validate.columns(inp, includes=[self.text_field])
it = inp[self.text_field]
if self.verbose:
it = pt.tqdm(it, desc='Encoding Docs', unit='doc')
Expand All @@ -118,8 +114,11 @@ def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, t
self.sim_fn = sim_fn if sim_fn is not None else bi_encoder_model.sim_fn

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert 'query_vec' in inp.columns or 'query' in inp.columns
assert 'doc_vec' in inp.columns or self.text_field in inp.columns
with pta.validate.any(inp) as v:
v.columns(includes=['query_vec', 'doc_vec'])
v.columns(includes=['query', 'doc_vec'])
v.columns(includes=['query_vec', self.text_field])
v.columns(includes=['query', self.text_field])
if 'query_vec' in inp.columns:
query_vec = inp['query_vec']
else:
Expand Down
Loading

0 comments on commit 9d17bb2

Please sign in to comment.