Skip to content

Commit

Permalink
Add tests and GitHub Action (#17)
Browse files Browse the repository at this point in the history
* add .gitignore

* sentence transformers is optional

* add npids as dependency

* add gha

* WIP - none of these tests pass

* add error detection for common misuse

* add memory index testing

* not sure why overwrite is needed

* bump action versions

* update tests to use FlexIndex

* test_flexindex

* test_flexindex

* test_models

* more test cases

* more test cases in test_models

* more test cases in test_models

added default transformations to FlexIndex

* removed expensive (and so far unused) GA dependencies

* lazy faiss dependency

* more tests

* subtests in pytest

* fix empty

* update model cache key

* fixing some broken gha tests

* more tests, fixed bug with torch_retriever, fixed rank indexing

* fix dependency cache

* more tests

* scann test

* better np_vec_loader test cases

* fix indentation in scann_retr.py

* fix gha dependency caching

* optional dependency package assertions

* fix faiss_retr.py indentation

* support for ivf when cuda not available

* more tests, correct rank index, etc.

* no more dependency caching; doesn't seem to work properly

* better re-ranking test case

* test_torch_vecs

* make test skip conditions consistent

---------

Co-authored-by: Sean MacAvaney <[email protected]>
  • Loading branch information
cmacdonald and seanmacavaney authored Oct 21, 2023
1 parent 53e965a commit 4813f1e
Show file tree
Hide file tree
Showing 21 changed files with 770 additions and 162 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: Test Python package

on: [push, pull_request]

jobs:
build:

strategy:
matrix:
python: [3.8]
java: [13]
os: ['ubuntu-latest'] #
architecture: ['x64']
terrier: ['snapshot'] #'5.3', '5.4-SNAPSHOT',

runs-on: ${{ matrix.os }}
steps:

- uses: actions/checkout@v3

- name: Setup java
uses: actions/setup-java@v3
with:
java-version: ${{ matrix.java }}
architecture: ${{ matrix.architecture }}
distribution: 'zulu'

- name: Setup conda
uses: s-weigand/setup-conda@v1
with:
python-version: ${{ matrix.python }}
conda-channels: anaconda, conda-forge
activate-conda: true

# follows https://medium.com/ai2-blog/python-caching-in-github-actions-e9452698e98d
- name: Loading Torch models from cache
uses: actions/cache@v3
with:
path: /home/runner/.cache/
key: model-cache

- name: Install Python dependencies
run: |
pip install --upgrade --upgrade-strategy eager -r requirements.txt -r requirements-dev.txt
conda install -c pytorch faiss-cpu=1.7.4 mkl=2021 blas=1.0=mkl
- name: All unit tests
env:
TERRIER_VERSION: ${{ matrix.terrier }}
run: |
pytest -s
131 changes: 131 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

.vscode/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
2 changes: 1 addition & 1 deletion pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from .flex import FlexIndex
from .biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
from .hgf_models import HgfBiEncoder, TasB, RetroMAE
from .sbert_models import SBertBiEncoder, Ance, Query2Query
from .sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from .tctcolbert_model import TctColBert
from .electra import ElectraScorer
3 changes: 3 additions & 0 deletions pyterrier_dr/biencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
columns = set(inp.columns)
modes = [
(['qid', 'query', self.text_field], self.scorer),
(['qid', 'query_vec', self.text_field], self.scorer),
(['qid', 'query', 'doc_vec'], self.scorer),
(['qid', 'query_vec', 'doc_vec'], self.scorer),
(['query'], self.query_encoder),
([self.text_field], self.doc_encoder),
]
Expand Down
24 changes: 16 additions & 8 deletions pyterrier_dr/flex/core.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
import re
import math
import os
import shutil
import threading
import struct
import tempfile
import itertools
import json
from pathlib import Path
from warnings import warn
import numpy as np
import shutil
import more_itertools
import pandas as pd
import pyterrier as pt
from pyterrier.model import add_ranks
from npids import Lookup
from enum import Enum
from .. import SimFn
from ..indexes import RankedLists, TorchRankedLists
from ..indexes import RankedLists
import ir_datasets
import torch

Expand Down Expand Up @@ -91,6 +85,20 @@ def index(self, inp):
with open(path/'pt_meta.json', 'wt') as f_meta:
json.dump({"type": "dense_index", "format": "flex", "vec_size": vec_size, "doc_count": count}, f_meta)

def transform(self, inp):
    """Dispatch ``inp`` to a default retriever chosen from its columns.

    The only mode currently registered is exhaustive search through
    ``self.np_retriever()``, which is selected when ``inp`` has both a
    ``qid`` and a ``query_vec`` column. A warning is emitted whenever a
    default is picked so callers know which retriever was used.

    Args:
        inp: frame-like object exposing ``columns``; it is passed unchanged
            to the selected retriever.

    Returns:
        Whatever the selected retriever returns for ``inp``.

    Raises:
        RuntimeError: if no registered mode has all of its required columns
            present in ``inp``. The message lists the supported modes.
    """
    columns = set(inp.columns)
    # Each mode is (required columns, retriever factory, note for the warning).
    modes = [
        (['qid', 'query_vec'], self.np_retriever, "performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster"),
    ]
    for fields, fn, note in modes:
        if all(f in columns for f in fields):
            # Report columns in their original order (a set has no stable order).
            warn(f'based on input columns {list(inp.columns)}, {note}')
            return fn()(inp)
    message = f'Unexpected input with columns: {list(inp.columns)}. Supports:'
    # Modes are 3-tuples; the note is not needed when describing what is supported.
    for fields, fn, _note in modes:
        # Fall back to the callable's name when it has no docstring.
        description = (fn.__doc__ or getattr(fn, '__name__', repr(fn))).strip()
        message += f'\n - {description}: {fields}'
    raise RuntimeError(message)

def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
docnos, dvecs, meta = self.payload()
docno_iter = iter(docnos)
Expand Down
13 changes: 8 additions & 5 deletions pyterrier_dr/flex/corpus_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ir_datasets
import torch
import numpy as np
import pyterrier_dr
from ..indexes import TorchRankedLists
from . import FlexIndex

Expand Down Expand Up @@ -40,20 +41,22 @@ def _build_corpus_graph(flex_index, k, out_dir, batch_size):
out_dir.mkdir(parents=True, exist_ok=True)
edges_path = out_dir/'edges.u32.np'
weights_path = out_dir/'weights.f16.np'
device = pyterrier_dr.util.infer_device()
dtype = torch.half if device.type == 'cuda' else torch.float
with logger.pbar_raw(total=int((num_chunks+1)*num_chunks/2), unit='chunk', smoothing=1) as pbar, \
ir_datasets.util.finialized_file(str(edges_path), 'wb') as fe, \
ir_datasets.util.finialized_file(str(weights_path), 'wb') as fw:
for i in range(num_chunks):
left = torch.from_numpy(vectors[i*S:(i+1)*S]).cuda().half()
left /= left.norm(dim=1, keepdim=True)
left = torch.from_numpy(vectors[i*S:(i+1)*S]).to(device).to(dtype)
left = left / left.norm(dim=1, keepdim=True)
scores = left @ left.T
scores[torch.eye(left.shape[0], dtype=bool, device='cuda')] = float('-inf')
scores[torch.eye(left.shape[0], dtype=bool, device=device)] = float('-inf')
i_scores, i_dids = scores.topk(k, sorted=True, dim=1)
rankings[i].update(i_scores, (i_dids + i*S))
pbar.update()
for j in range(i+1, num_chunks):
right = torch.from_numpy(vectors[j*S:(j+1)*S]).cuda().half()
right /= right.norm(dim=1, keepdim=True)
right = torch.from_numpy(vectors[j*S:(j+1)*S]).to(device).to(dtype)
right = right / right.norm(dim=1, keepdim=True)
scores = left @ right.T
i_scores, i_dids = scores.topk(min(k, right.shape[0]), sorted=True, dim=1)
j_scores, j_dids = scores.topk(min(k, left.shape[0]), sorted=True, dim=0)
Expand Down
Loading

0 comments on commit 4813f1e

Please sign in to comment.