From 4dae8e649523ff58b2e0b3c6bd3a2818738a41d4 Mon Sep 17 00:00:00 2001
From: Sean MacAvaney
Date: Wed, 4 Dec 2024 15:15:01 +0000
Subject: [PATCH] Documentation for `FlexIndex` (#31)

* flexindex documentation
* removed unused import
* name conflict
* copy-paste error
* faster tests
* misc
* misc
* : in directives mean something else, changed to .
* some model documentation
* bind staticmethod
* default model_name flow through
* biencoder documentation
* good enough, I just need to be done with this
---
 pyterrier_dr/bge_m3.py                      |   2 +-
 pyterrier_dr/biencoder.py                   |  68 ++++++-
 pyterrier_dr/cde.py                         |   2 +-
 pyterrier_dr/flex/core.py                   | 205 ++++++++++++++++----
 pyterrier_dr/flex/corpus_graph.py           |  17 +-
 pyterrier_dr/flex/faiss_retr.py             | 111 ++++++++++-
 pyterrier_dr/flex/gar.py                    |  18 +-
 pyterrier_dr/flex/ladr.py                   |  69 ++++++-
 pyterrier_dr/flex/np_retr.py                |  65 ++++++-
 pyterrier_dr/flex/scann_retr.py             |  34 +++-
 pyterrier_dr/flex/torch_retr.py             |  79 +++++++-
 pyterrier_dr/flex/voyager_retr.py           |  33 +++-
 pyterrier_dr/hgf_models.py                  |  32 ++-
 pyterrier_dr/prf.py                         |  46 +----
 pyterrier_dr/pt_docs/encoding.rst           |  31 +++
 pyterrier_dr/pt_docs/index.rst              |  42 +---
 pyterrier_dr/pt_docs/indexing-retrieval.rst |  91 +++++++++
 pyterrier_dr/pt_docs/overview.rst           | 138 +++++++++++++
 pyterrier_dr/pt_docs/prf.rst                |   4 +-
 pyterrier_dr/pt_docs/sbert.rst              |  24 ---
 pyterrier_dr/sbert_models.py                |  54 +++++-
 pyterrier_dr/tctcolbert_model.py            |  29 ++-
 pyterrier_dr/util.py                        |   5 +
 tests/test_models.py                        |   4 +-
 24 files changed, 988 insertions(+), 215 deletions(-)
 create mode 100644 pyterrier_dr/pt_docs/encoding.rst
 create mode 100644 pyterrier_dr/pt_docs/indexing-retrieval.rst
 create mode 100644 pyterrier_dr/pt_docs/overview.rst
 delete mode 100644 pyterrier_dr/pt_docs/sbert.rst

diff --git a/pyterrier_dr/bge_m3.py b/pyterrier_dr/bge_m3.py
index 780f36f..9915d7b 100644
--- a/pyterrier_dr/bge_m3.py
+++ b/pyterrier_dr/bge_m3.py
@@ -7,7 +7,7 @@ class BGEM3(BiEncoder):
 
     def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
-        super().__init__(batch_size, text_field, verbose)
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
         self.model_name = model_name
         self.use_fp16 = use_fp16
         self.max_length = max_length
diff --git a/pyterrier_dr/biencoder.py b/pyterrier_dr/biencoder.py
index 9e7684e..91f3a02 100644
--- a/pyterrier_dr/biencoder.py
+++ b/pyterrier_dr/biencoder.py
@@ -1,3 +1,5 @@
+from typing import List, Optional
+from abc import abstractmethod
 import numpy as np
 import pyterrier as pt
 import pandas as pd
@@ -6,18 +8,33 @@
 
 
 class BiEncoder(pt.Transformer):
-    def __init__(self, batch_size=32, text_field='text', verbose=False):
+    """Represents a single-vector dense bi-encoder.
+
+    A ``BiEncoder`` encodes the text of a query or document into a dense vector.
+
+    This class functions as a transformer factory:
+    - Query encoding using :meth:`query_encoder`
+    - Document encoding using :meth:`doc_encoder`
+    - Text scoring (re-ranking) using :meth:`text_scorer`
+
+    It can also be used as a transformer directly. It infers which transformer to use
+    based on columns present in the input frame.
+
+    Note that in most cases, you will want to use a ``BiEncoder`` as part of a pipeline
+    with a :class:`~pyterrier_dr.FlexIndex` to perform dense indexing and retrieval.
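+
+    For example (a brief sketch; :class:`~pyterrier_dr.TasB` stands in for any concrete
+    ``BiEncoder`` implementation):
+
+    .. code-block:: python
+
+        from pyterrier_dr import TasB
+        model = TasB.dot()
+        query_enc = model.query_encoder()  # encodes `query` into a `query_vec` column
+        doc_enc = model.doc_encoder()      # encodes `text` into a `doc_vec` column
+        scorer = model.text_scorer()       # scores query/text pairs for re-ranking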
+ """ + def __init__(self, *, batch_size=32, text_field='text', verbose=False): + """ + Args: + batch_size: The default batch size to use for query/document encoding + text_field: The field in the input dataframe that contains the document text + verbose: Whether to show progress bars + """ super().__init__() self.batch_size = batch_size self.text_field = text_field self.verbose = verbose - def encode_queries(self, texts, batch_size=None) -> np.array: - raise NotImplementedError() - - def encode_docs(self, texts, batch_size=None) -> np.array: - raise NotImplementedError() - def transform(self, inp: pd.DataFrame) -> pd.DataFrame: with pta.validate.any(inp) as v: v.columns(includes=['query', self.text_field], mode='scorer') @@ -46,12 +63,15 @@ def doc_encoder(self, verbose=None, batch_size=None) -> pt.Transformer: """ return BiDocEncoder(self, verbose=verbose, batch_size=batch_size) - def scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer: + def text_scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer: """ - Scoring (re-ranking) + Text Scoring (re-ranking) """ return BiScorer(self, verbose=verbose, batch_size=batch_size, sim_fn=sim_fn) + def scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer: + return self.text_scorer(verbose=verbose, batch_size=batch_size, sim_fn=sim_fn) + @property def sim_fn(self) -> SimFn: """ @@ -61,6 +81,36 @@ def sim_fn(self) -> SimFn: return SimFn(self.config.sim_fn) return SimFn.dot # default + @abstractmethod + def encode_queries(self, texts: List[str], batch_size: Optional[int] = None) -> np.array: + """Abstract method to encode a list of query texts into dense vectors. + + This function is used by the transformer returned by :meth:`query_encoder`. + + Args: + texts: A list of query texts + batch_size: The batch size to use for encoding + + Returns: + np.array: A numpy array of shape (n_queries, n_dims) + """ + raise NotImplementedError() + + @abstractmethod + def encode_docs(self, texts: List[str], batch_size: Optional[int] = None) -> np.array: + """Abstract method to encode a list of document texts into dense vectors. + + This function is used by the transformer returned by :meth:`doc_encoder`. 
+ + Args: + texts: A list of document texts + batch_size: The batch size to use for encoding + + Returns: + np.array: A numpy array of shape (n_docs, n_dims) + """ + raise NotImplementedError() + class BiQueryEncoder(pt.Transformer): def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None): diff --git a/pyterrier_dr/cde.py b/pyterrier_dr/cde.py index da6c39f..f421300 100644 --- a/pyterrier_dr/cde.py +++ b/pyterrier_dr/cde.py @@ -13,7 +13,7 @@ class CDE(BiEncoder): def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = None, batch_size=32, text_field='text', verbose=False, device=None): - super().__init__(batch_size, text_field, verbose) + super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose) self.model_name = model_name if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' diff --git a/pyterrier_dr/flex/core.py b/pyterrier_dr/flex/core.py index 399418c..6227dc4 100644 --- a/pyterrier_dr/flex/core.py +++ b/pyterrier_dr/flex/core.py @@ -1,8 +1,8 @@ +from typing import Union, Iterable, Dict import shutil import itertools import json from pathlib import Path -from warnings import warn import numpy as np import more_itertools import pandas as pd @@ -20,12 +20,32 @@ class IndexingMode(Enum): class FlexIndex(pta.Artifact, pt.Indexer): - def __init__(self, index_path, num_results=1000, sim_fn=SimFn.dot, indexing_mode=IndexingMode.create, verbose=True): - super().__init__(index_path) - self.index_path = Path(index_path) - self.num_results = num_results + """ Represents a FLexible EXecution (FLEX) Index, which is a dense index format. + + FLEX allows for a variety of retrieval implementations (NumPy, FAISS, etc.) and algorithms (exhaustive, HNSW, etc.) + to be tested. In most cases, the same vector storage can be used across implementations and algorithms, saving + considerably on disk space. + """ + + ARTIFACT_TYPE = 'dense_index' + ARTIFACT_FORMAT = 'flex' + + + def __init__(self, + path: str, + *, + sim_fn: Union[SimFn, str] = SimFn.dot, + verbose: bool = True + ): + """ + Args: + path: The path to the index directory + sim_fn: The similarity function to use + verbose: Whether to display verbose output (e.g., progress bars) + """ + super().__init__(path) + self.index_path = Path(path) self.sim_fn = SimFn(sim_fn) - self.indexing_mode = IndexingMode(indexing_mode) self.verbose = verbose self._meta = None self._docnos = None @@ -53,44 +73,88 @@ def __len__(self): meta, = self.payload(return_dvecs=False, return_docnos=False) return meta['doc_count'] - def index(self, inp): - if isinstance(inp, pd.DataFrame): - inp = inp.to_dict(orient="records") - inp = more_itertools.peekable(inp) - path = Path(self.index_path) - if path.exists(): - if self.indexing_mode == IndexingMode.overwrite: - shutil.rmtree(path) - else: - raise RuntimeError(f'Index already exists at {self.index_path}. 
If you want to delete and re-create an existing index, you can pass indexing_mode=IndexingMode.overwrite')
-        path.mkdir(parents=True, exist_ok=True)
-        vec_size = None
-        count = 0
-        if self.verbose:
-            inp = pt.tqdm(inp, desc='indexing', unit='dvec')
-        with open(path/'vecs.f4', 'wb') as fout, Lookup.builder(path/'docnos.npids') as docnos:
-            for d in inp:
-                vec = d['doc_vec']
-                if vec_size is None:
-                    vec_size = vec.shape[0]
-                elif vec_size != vec.shape[0]:
-                    raise ValueError(f'Inconsistent vector shapes detected (expected {vec_size} but found {vec.shape[0]})')
-                vec = vec.astype(np.float32)
-                fout.write(vec.tobytes())
-                docnos.add(d['docno'])
-                count += 1
-        with open(path/'pt_meta.json', 'wt') as f_meta:
-            json.dump({"type": "dense_index", "format": "flex", "vec_size": vec_size, "doc_count": count}, f_meta)
+    def index(self, inp: Iterable[Dict]) -> pta.Artifact:
+        """Index the given input data stream to a new index at this location.
+
+        Each record in ``inp`` is expected to be a dictionary containing at least two keys: ``docno`` (a unique document
+        identifier) and ``doc_vec`` (a dense vector representation of the document).
+
+        Typically this method will be used in a pipeline of operations, where the input data is first transformed by a
+        document encoder to add the ``doc_vec`` values before it is indexed. For example:
+
+        .. code-block:: python
+            :caption: Index documents into a :class:`~pyterrier_dr.FlexIndex` using a :class:`~pyterrier_dr.TasB` encoder.
+
+            from pyterrier_dr import TasB, FlexIndex
+            encoder = TasB.dot()
+            index = FlexIndex('my_index')
+            pipeline = encoder >> index
+            pipeline.index([
+                {'docno': 'doc1', 'text': 'hello'},
+                {'docno': 'doc2', 'text': 'world'},
+            ])
+
+        Args:
+            inp: An iterable of dictionaries to index.
+
+        Returns:
+            :class:`pyterrier_alpha.Artifact`: A reference back to this index (``self``).
+
+        Raises:
+            RuntimeError: If the index is already built.
+        """
+        return self.indexer().index(inp)
+
+    def indexer(self, *, mode: Union[IndexingMode, str] = IndexingMode.create) -> 'FlexIndexer':
+        """Return an indexer for this index with the specified options.
+
+        This transformer gives more fine-grained control over the indexing process, allowing you to specify whether
+        to create a new index or overwrite an existing one.
+
+        Similar to :meth:`index`, this method will typically be used in a pipeline of operations, where the input data
+        is first transformed by a document encoder to add the ``doc_vec`` values before it is indexed. For example:
+
+        .. code-block:: python
+            :caption: Overwrite a :class:`~pyterrier_dr.FlexIndex` using a :class:`~pyterrier_dr.TasB` encoder.
+
+            from pyterrier_dr import TasB, FlexIndex
+            encoder = TasB.dot()
+            index = FlexIndex('my_index')
+            pipeline = encoder >> index.indexer(mode='overwrite')
+            pipeline.index([
+                {'docno': 'doc1', 'text': 'hello'},
+                {'docno': 'doc2', 'text': 'world'},
+            ])
+
+        Args:
+            mode: The indexing mode to use (``create`` or ``overwrite``).
+
+        Returns:
+            :class:`~pyterrier.Indexer`: A new indexer instance.
+        """
+        return FlexIndexer(self, mode=mode)
 
     def transform(self, inp):
         with pta.validate.any(inp) as v:
-            v.query_frame(extra_columns=['query_vec'], mode='np_retriever')
+            v.query_frame(extra_columns=['query_vec'], mode='retriever')
+            v.result_frame(extra_columns=['query_vec'], mode='scorer')
+
+        if v.mode == 'retriever':
+            return self.retriever()(inp)
+        if v.mode == 'scorer':
+            return self.scorer()(inp)
+
+    def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True) -> Iterable[Dict]:
+        """Iterate over the documents in the index.
 
-        if v.mode == 'np_retriever':
-            warn("performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster")
-            return self.np_retriever()(inp)
+        Args:
+            start_idx: The index of the first document to return (or ``None`` to start at the first document).
+            stop_idx: The index of the document to stop at, exclusive (or ``None`` to continue through the last document).
+            verbose: Whether to display a progress bar.
 
-    def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
+        Yields:
+            Dict[str,Any]: A dictionary with keys ``docno`` and ``doc_vec``.
+        """
         docnos, dvecs, meta = self.payload()
         docno_iter = iter(docnos)
         if start_idx is not None or stop_idx is not None:
@@ -111,9 +175,70 @@ def _load_docids(self, inp):
         docnos, config = self.payload(return_dvecs=False)
         return docnos.inv[inp['docno'].values]  # look up docids from docnos
 
-    def built(self):
+    def built(self) -> bool:
+        """Check if the index has been built.
+
+        Returns:
+            bool: ``True`` if the index has been built, otherwise ``False``.
+        """
         return self.index_path.exists()
 
+    def docnos(self) -> Lookup:
+        """Return the document identifier (docno) lookup data structure.
+
+        Returns:
+            :class:`npids.Lookup`: The document number lookup.
+        """
+        docnos, meta = self.payload(return_dvecs=False)
+        return docnos
+
+
+class FlexIndexer(pt.Indexer):
+    def __init__(self, index: FlexIndex, mode: Union[IndexingMode, str] = IndexingMode.create):
+        self._index = index
+        self.mode = IndexingMode(mode)
+
+    def __repr__(self):
+        return f'{self._index}.indexer(mode={self.mode!r})'
+
+    def transform(self, inp):
+        raise RuntimeError("FlexIndexer cannot be used as a transformer, use .index() instead")
+
+    def index(self, inp):
+        if isinstance(inp, pd.DataFrame):
+            inp = inp.to_dict(orient="records")
+        inp = more_itertools.peekable(inp)
+        path = Path(self._index.index_path)
+        if path.exists():
+            if self.mode == IndexingMode.overwrite:
+                shutil.rmtree(path)
+            else:
+                raise RuntimeError(f'Index already exists at {self._index.index_path}. 
If you want to delete and re-create an existing index, you can pass index.indexer(mode="overwrite")') + path.mkdir(parents=True, exist_ok=True) + vec_size = None + count = 0 + if self._index.verbose: + inp = pt.tqdm(inp, desc='indexing', unit='dvec') + with open(path/'vecs.f4', 'wb') as fout, Lookup.builder(path/'docnos.npids') as docnos: + for d in inp: + vec = d['doc_vec'] + if vec_size is None: + vec_size = vec.shape[0] + elif vec_size != vec.shape[0]: + raise ValueError(f'Inconsistent vector shapes detected (expected {vec_size} but found {vec.shape[0]})') + vec = vec.astype(np.float32) + fout.write(vec.tobytes()) + docnos.add(d['docno']) + count += 1 + with open(path/'pt_meta.json', 'wt') as f_meta: + json.dump({ + "type": self._index.ARTIFACT_TYPE, + "format": self._index.ARTIFACT_FORMAT, + "vec_size": vec_size, + "doc_count": count + }, f_meta) + return self._index + def _load_dvecs(flex_index, inp): dvecs, config = flex_index.payload(return_docnos=False) diff --git a/pyterrier_dr/flex/corpus_graph.py b/pyterrier_dr/flex/corpus_graph.py index e99fb06..516f302 100644 --- a/pyterrier_dr/flex/corpus_graph.py +++ b/pyterrier_dr/flex/corpus_graph.py @@ -10,7 +10,22 @@ from . import FlexIndex -def _corpus_graph(self, k=16, batch_size=8192): +def _corpus_graph(self, k: int = 16, *, batch_size: int = 8192): + """Return the corpus graph (neighborhood graph) for the index. + + The corpus graph is a directed graph where each node represents a document and each edge represents a + connection between two documents. The graph is built by computing the cosine similarity between each + pair of documents and storing the k-nearest neighbors for each document. + + If the corpus graph has not been built yet, it will be built using the given k and batch size. + + Args: + k: The number of neighbors to store for each document. + batch_size: The number of vectors to process in each batch. + + Returns: + :class:`pyterrier_adaptive.CorpusGraph`: The corpus graph for the index. 
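+
+    For example (a brief sketch; assumes an index has already been built at ``my_index.flex``
+    and that the ``pyterrier-adaptive`` package is installed):
+
+    .. code-block:: python
+
+        from pyterrier_dr import FlexIndex
+        index = FlexIndex('my_index.flex')
+        graph = index.corpus_graph(k=16)  # built and cached on first call
+        graph.neighbours(42)  # ids of the 16 documents most similar to docid 42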
+    """
     from pyterrier_adaptive import CorpusGraph
     key = ('corpus_graph', k)
     if key not in self._cache:
diff --git a/pyterrier_dr/flex/faiss_retr.py b/pyterrier_dr/flex/faiss_retr.py
index a1a1d22..fab8c6b 100644
--- a/pyterrier_dr/flex/faiss_retr.py
+++ b/pyterrier_dr/flex/faiss_retr.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import json
 import math
 import struct
@@ -14,7 +15,7 @@
 
 
 class FaissRetriever(pt.Indexer):
-    def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64, drop_query_vec=False):
+    def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search_bounded_queue=None, qbatch=64, drop_query_vec=False, num_results=1000):
         self.flex_index = flex_index
         self.faiss_index = faiss_index
         self.n_probe = n_probe
@@ -22,6 +23,7 @@ def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search
         self.search_bounded_queue = search_bounded_queue
         self.qbatch = qbatch
         self.drop_query_vec = drop_query_vec
+        self.num_results = num_results
 
     def transform(self, inp):
         pta.validate.query_frame(inp, extra_columns=['query_vec'])
@@ -43,7 +45,7 @@ def transform(self, inp):
 
         result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
         for qidx in it:
-            scores, dids = self.faiss_index.search(query_vecs[qidx:qidx+QBATCH], self.flex_index.num_results)
+            scores, dids = self.faiss_index.search(query_vecs[qidx:qidx+QBATCH], self.num_results)
             for s, d in zip(scores, dids):
                 mask = d != -1
                 d = d[mask]
@@ -60,7 +62,22 @@ def transform(self, inp):
         return result.to_df(inp)
 
 
-def _faiss_flat_retriever(self, gpu=False, qbatch=64, drop_query_vec=False):
+def _faiss_flat_retriever(self, *, gpu=False, qbatch=64, drop_query_vec=False):
+    """Returns a retriever that uses FAISS to perform brute-force search over the indexed vectors.
+
+    Args:
+        gpu: Whether to load the index onto GPU for scoring
+        qbatch: The batch size during search
+        drop_query_vec: Whether to drop the query vector from the output
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A retriever that uses FAISS to perform brute-force search over the indexed vectors
+
+    .. note::
+        This transformer requires the ``faiss`` package to be installed.
+
+    .. cite.dblp:: journals/corr/abs-2401-08281
+    """
     pyterrier_dr.util.assert_faiss()
     import faiss
     if 'faiss_flat' not in self._cache:
@@ -85,7 +102,41 @@ def _faiss_flat_retriever(self, gpu=False, qbatch=64, drop_query_vec=False):
 FlexIndex.faiss_flat_retriever = _faiss_flat_retriever
 
 
-def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, cache=True, search_bounded_queue=True, qbatch=64, drop_query_vec=False):
+def _faiss_hnsw_retriever(
+    self,
+    neighbours: int = 32,
+    *,
+    num_results: int = 1000,
+    ef_construction: int = 40,
+    ef_search: int = 16,
+    cache: bool = True,
+    search_bounded_queue: bool = True,
+    qbatch: int = 64,
+    drop_query_vec: bool = False,
+) -> pt.Transformer:
+    """Returns a retriever that uses FAISS over a HNSW index.
+
+    Creates the HNSW graph structure if it does not already exist. When ``cache=True`` (default), this graph structure is
+    cached to disk for subsequent use.
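+
+    For example (a brief sketch; assumes an existing index, a query encoder such as
+    :class:`~pyterrier_dr.TasB`, and the ``faiss`` package):
+
+    .. code-block:: python
+
+        from pyterrier_dr import FlexIndex, TasB
+        index = FlexIndex('my_index.flex')
+        # approximate retrieval over the HNSW graph; raise ef_search for better recall
+        pipeline = TasB.dot().query_encoder() >> index.faiss_hnsw_retriever(ef_search=64)
+        pipeline.search('chemical reactions')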
+ + Args: + neighbours: The number of neighbours of the constructed neighborhood graph + num_results: The number of results to return per query + ef_construction: The number of neighbours to consider during construction + ef_search: The number of neighbours to consider during search + cache: Whether to cache the index to disk + search_bounded_queue: Whether to use a bounded queue during search + qbatch: The batch size during search + drop_query_vec: Whether to drop the query vector from the output + + Returns: + :class:`~pyterrier.Transformer`: A retriever that uses FAISS over a HNSW index + + .. note:: + This transformer requires the ``faiss`` package to be installed. + + .. cite.dblp:: journals/corr/abs-2401-08281 + """ pyterrier_dr.util.assert_faiss() import faiss meta, = self.payload(return_dvecs=False, return_docnos=False) @@ -107,11 +158,27 @@ def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16, with logger.duration('reading hnsw table'): self._cache[key] = faiss.read_index(str(self.index_path/index_name)) self._cache[key].storage = self.faiss_flat_retriever().faiss_index - return FaissRetriever(self, self._cache[key], ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch, drop_query_vec=drop_query_vec) + return FaissRetriever(self, self._cache[key], num_results=num_results, ef_search=ef_search, search_bounded_queue=search_bounded_queue, qbatch=qbatch, drop_query_vec=drop_query_vec) FlexIndex.faiss_hnsw_retriever = _faiss_hnsw_retriever -def _faiss_hnsw_graph(self, neighbours=32, ef_construction=40): +def _faiss_hnsw_graph(self, neighbours: int = 32, *, ef_construction: int = 40): + """Returns the (approximate) HNSW graph structure created by the HNSW index. + + If the graph structure does not already exist, it is created and cached to disk. + + Args: + neighbours: The number of neighbours of the constructed neighborhood graph + ef_construction: The number of neighbours to consider during construction + + Returns: + :class:`pyterrier_adaptive.CorpusGraph`: The HNSW graph structure + + .. note:: + This function requires the ``faiss`` package to be installed. + + .. cite.dblp:: journals/corr/abs-2401-08281 + """ key = ('faiss_hnsw', neighbours//2, ef_construction) graph_name = f'hnsw_n-{neighbours}_ef-{ef_construction}.graph' if key not in self._cache: @@ -154,7 +221,35 @@ def _sample_train(index, count=None): idxs = np.random.RandomState(0).choice(dvecs.shape[0], size=count, replace=False) return dvecs[idxs] -def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_probe=1, drop_query_vec=False): +def _faiss_ivf_retriever(self, + *, + num_results: int = 1000, + train_sample: Optional[int] = None, + n_list: Optional[int] = None, + cache: bool = True, + n_probe: int = 1, + drop_query_vec: bool = False +): + """Returns a retriever that uses FAISS over an IVF index. + + If the IVF structure does not already exist, it is created and cached to disk (when ``cache=True`` (default)). + + Args: + num_results: The number of results to return per query + train_sample: The number of training samples to use for training the index. If not provided, a default value is used (approximately the square root of the number of documents). + n_list: The number of posting lists to use for the index. If not provided, a default value is used (approximately ``train_sample/39``). + cache: Whether to cache the index to disk. + n_probe: The number of posting lists to probe during search. 
The higher the value, the better the approximation will be, but the longer it will take. + drop_query_vec: Whether to drop the query vector from the output. + + Returns: + :class:`~pyterrier.Transformer`: A retriever that uses FAISS over an IVF index + + .. note:: + This transformer requires the ``faiss`` package to be installed. + + .. cite.dblp:: journals/corr/abs-2401-08281 + """ pyterrier_dr.util.assert_faiss() import faiss meta, = self.payload(return_dvecs=False, return_docnos=False) @@ -197,5 +292,5 @@ def _faiss_ivf_retriever(self, train_sample=None, n_list=None, cache=True, n_pro else: with logger.duration('reading index'): self._cache[key] = faiss.read_index(str(self.index_path/index_name)) - return FaissRetriever(self, self._cache[key], n_probe=n_probe, drop_query_vec=drop_query_vec) + return FaissRetriever(self, self._cache[key], num_results=num_results, n_probe=n_probe, drop_query_vec=drop_query_vec) FlexIndex.faiss_ivf_retriever = _faiss_ivf_retriever diff --git a/pyterrier_dr/flex/gar.py b/pyterrier_dr/flex/gar.py index a9356c8..2f3c818 100644 --- a/pyterrier_dr/flex/gar.py +++ b/pyterrier_dr/flex/gar.py @@ -65,7 +65,23 @@ def transform(self, inp): return all_results.to_df() +def _gar(self, + k: int = 16, + *, + batch_size: int = 128, + num_results: int = 1000 +) -> pt.Transformer: + """Returns a retriever that uses a corpus graph to search over a FlexIndex. -def _gar(self, k=16, batch_size=128, num_results=1000): + Args: + k (int): Number of neighbours in the corpus graph. Defaults to 16. + batch_size (int): Batch size for retrieval. Defaults to 128. + num_results (int): Number of results per query to return. Defaults to 1000. + + Returns: + :class:`~pyterrier.Transformer`: A retriever that uses a corpus graph to search over a FlexIndex. + + .. 
cite.dblp:: conf/cikm/MacAvaneyTM22
+    """
     return FlexGar(self, self.corpus_graph(k), SimFn.dot, batch_size=batch_size, num_results=num_results)
 FlexIndex.gar = _gar
diff --git a/pyterrier_dr/flex/ladr.py b/pyterrier_dr/flex/ladr.py
index 660140b..bc79e46 100644
--- a/pyterrier_dr/flex/ladr.py
+++ b/pyterrier_dr/flex/ladr.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import numpy as np
 import pyterrier as pt
 import pyterrier_alpha as pta
@@ -5,10 +6,11 @@
 
 
 class LadrPreemptive(pt.Transformer):
-    def __init__(self, flex_index, graph, dense_scorer, hops=1, drop_query_vec=False):
+    def __init__(self, flex_index, graph, dense_scorer, num_results=1000, hops=1, drop_query_vec=False):
         self.flex_index = flex_index
         self.graph = graph
         self.dense_scorer = dense_scorer
+        self.num_results = num_results
         self.hops = hops
         self.drop_query_vec = drop_query_vec
 
@@ -35,8 +37,8 @@ def transform(self, inp):
             query_vecs = df['query_vec'].iloc[0].reshape(1, -1)
             scores = self.dense_scorer.score(query_vecs, ext_docids)
             scores = scores.reshape(-1)
-            if scores.shape[0] > self.flex_index.num_results:
-                idxs = np.argpartition(scores, -self.flex_index.num_results)[-self.flex_index.num_results:]
+            if scores.shape[0] > self.num_results:
+                idxs = np.argpartition(scores, -self.num_results)[-self.num_results:]
             else:
                 idxs = np.arange(scores.shape[0])
             docids, scores = ext_docids[idxs], scores[idxs]
@@ -51,17 +53,40 @@ def transform(self, inp):
         return all_results.to_df()
 
 
-def _pre_ladr(self, k=16, hops=1, dense_scorer=None, drop_query_vec=False):
+def _pre_ladr(self,
+    k: int = 16,
+    *,
+    hops: int = 1,
+    num_results: int = 1000,
+    dense_scorer: Optional[pt.Transformer] = None,
+    drop_query_vec: bool = False
+) -> pt.Transformer:
+    """Returns a proactive LADR (Lexically-Accelerated Dense Retrieval) transformer.
+
+    Args:
+        k (int): The number of neighbours in the corpus graph.
+        hops (int): The number of hops to consider. Defaults to 1.
+        num_results (int): The number of results to return per query.
+        dense_scorer (:class:`~pyterrier.Transformer`, optional): The dense scorer to use. Defaults to :meth:`np_scorer`.
+        drop_query_vec (bool): Whether to drop the query vector from the output.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A proactive LADR transformer.
+
+    .. cite.dblp:: conf/sigir/KulkarniMGF23
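+
+    For example (a brief sketch; assumes an existing index, a first-stage lexical retriever
+    ``bm25``, and :class:`~pyterrier_dr.TasB` as the query encoder):
+
+    .. code-block:: python
+
+        from pyterrier_dr import FlexIndex, TasB
+        index = FlexIndex('my_index.flex')
+        # dense re-ranking that expands BM25's results via the corpus graph
+        pipeline = bm25 >> TasB.dot().query_encoder() >> index.ladr_proactive(k=16)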
+    """
     graph = self.corpus_graph(k) if isinstance(k, int) else k
-    return LadrPreemptive(self, graph, hops=hops, dense_scorer=dense_scorer or self.scorer(), drop_query_vec=drop_query_vec)
+    return LadrPreemptive(self, graph, num_results=num_results, hops=hops, dense_scorer=dense_scorer or self.scorer(), drop_query_vec=drop_query_vec)
 FlexIndex.ladr = _pre_ladr  # TODO: remove this alias later
 FlexIndex.pre_ladr = _pre_ladr
+FlexIndex.ladr_proactive = _pre_ladr
 
 
 class LadrAdaptive(pt.Transformer):
-    def __init__(self, flex_index, graph, dense_scorer, depth=100, max_hops=None, drop_query_vec=False):
+    def __init__(self, flex_index, graph, dense_scorer, num_results=1000, depth=100, max_hops=None, drop_query_vec=False):
         self.flex_index = flex_index
         self.graph = graph
         self.dense_scorer = dense_scorer
+        self.num_results = num_results
         self.depth = depth
         self.max_hops = max_hops
         self.drop_query_vec = drop_query_vec
@@ -100,8 +125,8 @@ def transform(self, inp):
                 docids = cat_dids[idxs]
                 scores = np.concatenate([scores, neighbour_scores])[idxs]
                 rnd += 1
-            if scores.shape[0] > self.flex_index.num_results:
-                idxs = np.argpartition(scores, -self.flex_index.num_results)[-self.flex_index.num_results:]
+            if scores.shape[0] > self.num_results:
+                idxs = np.argpartition(scores, -self.num_results)[-self.num_results:]
             else:
                 idxs = np.arange(scores.shape[0])
             docids, scores = docids[idxs], scores[idxs]
@@ -115,7 +140,31 @@ def transform(self, inp):
         ))
         return all_results.to_df()
 
-def _ada_ladr(self, k=16, dense_scorer=None, depth=100, max_hops=None, drop_query_vec=False):
+def _ada_ladr(self,
+    k: int = 16,
+    *,
+    depth: int = 100,
+    num_results: int = 1000,
+    dense_scorer: Optional[pt.Transformer] = None,
+    max_hops: Optional[int] = None,
+    drop_query_vec: bool = False
+) -> pt.Transformer:
+    """Returns an adaptive LADR (Lexically-Accelerated Dense Retrieval) transformer.
+
+    Args:
+        k (int): The number of neighbours in the corpus graph.
+        depth (int): The depth of the ranked list to consider for convergence.
+        num_results (int): The number of results to return per query.
+        dense_scorer (:class:`~pyterrier.Transformer`, optional): The dense scorer to use. Defaults to :meth:`np_scorer`.
+        max_hops (int, optional): The maximum number of hops to consider. Defaults to ``None`` (no limit).
+        drop_query_vec (bool): Whether to drop the query vector from the output.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: An adaptive LADR transformer.
+
+    .. cite.dblp:: conf/sigir/KulkarniMGF23
+    """
     graph = self.corpus_graph(k) if isinstance(k, int) else k
-    return LadrAdaptive(self, graph, dense_scorer=dense_scorer or self.scorer(), depth=depth, max_hops=max_hops, drop_query_vec=drop_query_vec)
+    return LadrAdaptive(self, graph, num_results=num_results, dense_scorer=dense_scorer or self.scorer(), depth=depth, max_hops=max_hops, drop_query_vec=drop_query_vec)
 FlexIndex.ada_ladr = _ada_ladr
+FlexIndex.ladr_adaptive = _ada_ladr
diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py
index c0e6c2f..620b135 100644
--- a/pyterrier_dr/flex/np_retr.py
+++ b/pyterrier_dr/flex/np_retr.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import pyterrier as pt
 import numpy as np
 import pandas as pd
@@ -8,7 +9,13 @@
 
 
 class NumpyRetriever(pt.Transformer):
-    def __init__(self, flex_index, num_results=1000, batch_size=None, drop_query_vec=False):
+    def __init__(self,
+        flex_index: FlexIndex,
+        *,
+        num_results: int = 1000,
+        batch_size: Optional[int] = None,
+        drop_query_vec: bool = False
+    ):
         self.flex_index = flex_index
         self.num_results = num_results
         self.batch_size = batch_size or 4096
@@ -53,7 +60,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
 
 
 class NumpyVectorLoader(pt.Transformer):
-    def __init__(self, flex_index):
+    def __init__(self, flex_index: FlexIndex):
         self.flex_index = flex_index
 
     def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
@@ -63,7 +70,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
 
 
 class NumpyScorer(pt.Transformer):
-    def __init__(self, flex_index, num_results=None):
+    def __init__(self, flex_index: FlexIndex, *, num_results: Optional[int] = None):
         self.flex_index = flex_index
         self.num_results = num_results
 
@@ -79,7 +86,7 @@ def score(self, query_vecs, docids):
         else:
             raise ValueError(f'{self.flex_index.sim_fn} not supported')
 
-    def transform(self, inp):
+    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
         with pta.validate.any(inp) as v:
             v.columns(includes=['query_vec', 'docno'])
             v.columns(includes=['query_vec', 'docid'])
@@ -108,19 +115,63 @@ def transform(self, inp):
         res = res.assign(score=res_scores, rank=res_ranks)
         return res
 
+def _np_vecs(self) -> np.ndarray:
+    """Return the indexed vectors.
+
+    Returns:
+        :class:`numpy.ndarray`: The indexed vectors as a memory-mapped numpy array.
+    """
+    dvecs, meta = self.payload(return_docnos=False)
+    return dvecs
+FlexIndex.np_vecs = _np_vecs
+
+def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False):
+    """Return a retriever that uses numpy to perform a brute force search over the index.
 
-def _np_retriever(self, num_results=1000, batch_size=None, drop_query_vec=False):
+    The returned transformer expects a DataFrame with columns ``qid`` and ``query_vec``. It outputs
+    a result frame containing the retrieved results.
+
+    Args:
+        num_results: The number of results to return per query.
+        batch_size: The number of documents to score in each batch.
+        drop_query_vec: Whether to drop the query vector from the output.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A retriever that uses numpy to perform a brute force search.
+    """
     return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec)
 FlexIndex.np_retriever = _np_retriever
+FlexIndex.retriever = _np_retriever  # default retriever
 
 
 def _np_vec_loader(self):
+    """Return a transformer that loads indexed vectors.
+
+    The returned transformer expects a DataFrame with columns ``docno``. It outputs a frame that
+    includes a column ``doc_vec``, which contains the indexed vectors.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A transformer that loads indexed vectors.
+    """
     return NumpyVectorLoader(self)
 FlexIndex.np_vec_loader = _np_vec_loader
 FlexIndex.vec_loader = _np_vec_loader  # default vec_loader
 
 
-def _np_scorer(self, num_results=None):
-    return NumpyScorer(self, num_results)
+def _np_scorer(self, *, num_results: Optional[int] = None) -> pt.Transformer:
+    """Return a scorer that uses numpy to score (re-rank) results using indexed vectors.
+
+    The returned transformer expects a DataFrame with columns ``qid``, ``query_vec`` and ``docno``.
+    (If an internal ``docid`` column is provided, this will be used to speed up vector lookups.)
+
+    This method uses memory-mapping to avoid loading the entire index into memory at once.
+
+    Args:
+        num_results: The number of results to return per query. If not provided, all results from the original frame are returned.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A transformer that scores query vectors with numpy.
+    """
+    return NumpyScorer(self, num_results=num_results)
 FlexIndex.np_scorer = _np_scorer
 FlexIndex.scorer = _np_scorer  # default scorer
diff --git a/pyterrier_dr/flex/scann_retr.py b/pyterrier_dr/flex/scann_retr.py
index e55eb9a..025e3b2 100644
--- a/pyterrier_dr/flex/scann_retr.py
+++ b/pyterrier_dr/flex/scann_retr.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import math
 import os
 import pyterrier as pt
@@ -11,10 +12,11 @@
 
 
 class ScannRetriever(pt.Indexer):
-    def __init__(self, flex_index, scann_index, leaves_to_search=None, qbatch=64, drop_query_vec=False):
+    def __init__(self, flex_index, scann_index, num_results=1000, leaves_to_search=None, qbatch=64, drop_query_vec=False):
         self.flex_index = flex_index
         self.scann_index = scann_index
         self.leaves_to_search = leaves_to_search
+        self.num_results = num_results
         self.qbatch = qbatch
         self.drop_query_vec = drop_query_vec
 
@@ -29,7 +31,7 @@ def transform(self, inp):
         num_q = query_vecs.shape[0]
         QBATCH = self.qbatch
         for qidx in range(0, num_q, QBATCH):
-            dids, scores = self.scann_index.search_batched(query_vecs[qidx:qidx+QBATCH], leaves_to_search=self.leaves_to_search, final_num_neighbors=self.flex_index.num_results)
+            dids, scores = self.scann_index.search_batched(query_vecs[qidx:qidx+QBATCH], leaves_to_search=self.leaves_to_search, final_num_neighbors=self.num_results)
             for s, d in zip(scores, dids):
                 mask = d != -1
                 d = d[mask]
@@ -46,7 +48,31 @@ def transform(self, inp):
         return result.to_df(inp)
 
 
-def _scann_retriever(self, n_leaves=None, leaves_to_search=1, train_sample=None, drop_query_vec=False):
+def _scann_retriever(self,
+    *,
+    n_leaves: Optional[int] = None,
+    leaves_to_search: int = 1,
+    num_results: int = 1000,
+    train_sample: Optional[int] = None,
+    drop_query_vec=False
+):
+    """Returns a retriever over a ScaNN (Scalable Nearest Neighbors) index.
+
+    Args:
+        n_leaves (int, optional): Number of leaves in the ScaNN index. Defaults to approximately sqrt(doc_count).
+        leaves_to_search (int, optional): Number of leaves to search. Defaults to 1. The higher the value, the more accurate the search.
+        num_results (int, optional): Number of results to return. Defaults to 1000.
+        train_sample (int, optional): Number of training samples. Defaults to ``n_leaves*39``.
+        drop_query_vec (bool, optional): Whether to drop the query vector from the output.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A transformer that retrieves using ScaNN.
+
+    .. 
note:: + This method requires the ``scann`` package. Install it via ``pip install scann``. + + .. cite.dblp:: conf/icml/GuoSLGSCK20 + """ pyterrier_dr.util.assert_scann() import scann dvecs, meta, = self.payload(return_docnos=False) @@ -79,5 +105,5 @@ def _scann_retriever(self, n_leaves=None, leaves_to_search=1, train_sample=None, else: with logger.duration('reading index'): self._cache[key] = scann.scann_ops_pybind.load_searcher(dvecs, str(self.index_path/index_name)) - return ScannRetriever(self, self._cache[key], leaves_to_search=leaves_to_search, drop_query_vec=drop_query_vec) + return ScannRetriever(self, self._cache[key], num_results=num_results, leaves_to_search=leaves_to_search, drop_query_vec=drop_query_vec) FlexIndex.scann_retriever = _scann_retriever diff --git a/pyterrier_dr/flex/torch_retr.py b/pyterrier_dr/flex/torch_retr.py index da3e133..589b151 100644 --- a/pyterrier_dr/flex/torch_retr.py +++ b/pyterrier_dr/flex/torch_retr.py @@ -1,3 +1,4 @@ +from typing import Optional import numpy as np import torch import pyterrier_alpha as pta @@ -8,7 +9,12 @@ class TorchScorer(NumpyScorer): - def __init__(self, flex_index, torch_vecs, num_results=None): + def __init__(self, + flex_index: FlexIndex, + torch_vecs: torch.Tensor, + *, + num_results: Optional[int] = None + ): self.flex_index = flex_index self.torch_vecs = torch_vecs self.num_results = num_results @@ -30,7 +36,14 @@ def score(self, query_vecs, docids): class TorchRetriever(pt.Transformer): - def __init__(self, flex_index, torch_vecs, num_results=None, qbatch=64, drop_query_vec=False): + def __init__(self, + flex_index: FlexIndex, + torch_vecs: torch.Tensor, + *, + num_results: int = 1000, + qbatch: int = 64, + drop_query_vec: bool = False + ): self.flex_index = flex_index self.torch_vecs = torch_vecs self.num_results = num_results or 1000 @@ -74,7 +87,21 @@ def transform(self, inp): return result.to_df(inp) -def _torch_vecs(self, device=None, fp16=False): +def _torch_vecs(self, *, device: Optional[str] = None, fp16: bool = False) -> torch.Tensor: + """Return the indexed vectors as a pytorch tensor. + + .. caution:: + This method loads the entire index into memory on the provided device. If the index is too large to fit in memory, + consider using a different method that does not fully load the index into memory, like :meth:`np_vecs` or + :meth:`get_corpus_iter`. + + Args: + device: The device to use for the tensor. If not provided, the default device is used (cuda if available, otherwise cpu). + fp16: Whether to use half precision (fp16) for the tensor. + + Returns: + :class:`torch.Tensor`: The indexed vectors as a torch tensor. + """ device = infer_device(device) key = ('torch_vecs', device, fp16) if key not in self._cache: @@ -87,11 +114,53 @@ def _torch_vecs(self, device=None, fp16=False): FlexIndex.torch_vecs = _torch_vecs -def _torch_scorer(self, num_results=None, device=None, fp16=False): +def _torch_scorer(self, *, num_results: Optional[int] = None, device: Optional[str] = None, fp16: bool = False): + """Return a scorer that uses pytorch to score (re-rank) results using indexed vectors. + + The returned :class:`pyterrier.Transformer` expects a DataFrame with columns ``qid``, ``query_vec`` and ``docno``. + (If an internal ``docid`` column is provided, this will be used to speed up vector lookups.) + + .. caution:: + This method loads the entire index into memory on the provided device. 
If the index is too large to fit in memory,
+        consider using a different scorer that does not fully load the index into memory, like :meth:`np_scorer`.
+
+    Args:
+        num_results: The number of results to return per query. If not provided, all results from the original frame are returned.
+        device: The device to use for scoring. If not provided, the default device is used (cuda if available, otherwise cpu).
+        fp16: Whether to use half precision (fp16) for scoring.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A transformer that scores query vectors with pytorch.
+    """
     return TorchScorer(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results)
 FlexIndex.torch_scorer = _torch_scorer
 
 
-def _torch_retriever(self, num_results=None, device=None, fp16=False, qbatch=64, drop_query_vec=False):
+def _torch_retriever(self,
+    *,
+    num_results: int = 1000,
+    device: Optional[str] = None,
+    fp16: bool = False,
+    qbatch: int = 64,
+    drop_query_vec: bool = False
+):
+    """Return a retriever that uses pytorch to perform brute-force retrieval over the indexed vectors.
+
+    The returned :class:`pyterrier.Transformer` expects a DataFrame with columns ``qid``, ``query_vec``.
+
+    .. caution::
+        This method loads the entire index into memory on the provided device. If the index is too large to fit in memory,
+        consider using a different retriever that does not fully load the index into memory, like :meth:`np_retriever`.
+
+    Args:
+        num_results: The number of results to return per query.
+        device: The device to use for scoring. If not provided, the default device is used (cuda if available, otherwise cpu).
+        fp16: Whether to use half precision (fp16) for scoring.
+        qbatch: The number of queries to score in each batch.
+        drop_query_vec: Whether to drop the query vector from the output.
+
+    Returns:
+        :class:`~pyterrier.Transformer`: A transformer that retrieves using pytorch.
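+
+    For example (a brief sketch; assumes an existing index and a query encoder such as
+    :class:`~pyterrier_dr.TasB`):
+
+    .. code-block:: python
+
+        from pyterrier_dr import FlexIndex, TasB
+        index = FlexIndex('my_index.flex')
+        # encode queries, then retrieve with the torch backend (fp16 halves memory use)
+        pipeline = TasB.dot().query_encoder() >> index.torch_retriever(fp16=True)
+        pipeline.search('chemical reactions')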
+ """ return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec) FlexIndex.torch_retriever = _torch_retriever diff --git a/pyterrier_dr/flex/voyager_retr.py b/pyterrier_dr/flex/voyager_retr.py index 9bfa07e..785812f 100644 --- a/pyterrier_dr/flex/voyager_retr.py +++ b/pyterrier_dr/flex/voyager_retr.py @@ -10,10 +10,11 @@ class VoyagerRetriever(pt.Indexer): - def __init__(self, flex_index, voyager_index, query_ef=None, qbatch=64, drop_query_vec=False): + def __init__(self, flex_index, voyager_index, query_ef=None, num_results=1000, qbatch=64, drop_query_vec=False): self.flex_index = flex_index self.voyager_index = voyager_index self.query_ef = query_ef + self.num_results = num_results self.qbatch = qbatch self.drop_query_vec = drop_query_vec @@ -32,7 +33,7 @@ def transform(self, inp): it = pt.tqdm(it, unit='qbatch') for qidx in it: qvec_batch = query_vecs[qidx:qidx+QBATCH] - neighbor_ids, distances = self.voyager_index.query(qvec_batch, self.flex_index.num_results, self.query_ef) + neighbor_ids, distances = self.voyager_index.query(qvec_batch, self.num_results, self.query_ef) for s, d in zip(distances, neighbor_ids): mask = d != -1 d = d[mask] @@ -49,7 +50,33 @@ def transform(self, inp): return result.to_df(inp) -def _voyager_retriever(self, neighbours=12, ef_construction=200, random_seed=1, storage_data_type='float32', query_ef=10, drop_query_vec=False): +def _voyager_retriever(self, + neighbours: int = 12, + *, + num_results: int = 1000, + ef_construction: int = 200, + random_seed: int = 1, + storage_data_type: str = 'float32', + query_ef: int = 10, + drop_query_vec: bool = False +) -> pt.Transformer: + """Returns a retriever that uses HNSW to search over a Voyager index. + + Args: + neighbours (int, optional): Number of neighbours to search. Defaults to 12. + num_results (int, optional): Number of results to return per query. Defaults to 1000. + ef_construction (int, optional): Expansion factor for graph construction. Defaults to 200. + random_seed (int, optional): Random seed. Defaults to 1. + storage_data_type (str, optional): Storage data type. One of 'float32', 'float8', 'e4m3'. Defaults to 'float32'. + query_ef (int, optional): Expansion factor during querying. Defaults to 10. + drop_query_vec (bool, optional): Drop the query vector from the output. Defaults to False. + + Returns: + :class:`~pyterrier.Transformer`: A retriever that uses HNSW to search over a Voyager index. + + .. note:: + This method requires the ``voyager`` package. Install it via ``pip install voyager``. 
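+
+    For example (a brief sketch; assumes an existing index and a query encoder such as
+    :class:`~pyterrier_dr.TasB`):
+
+    .. code-block:: python
+
+        from pyterrier_dr import FlexIndex, TasB
+        index = FlexIndex('my_index.flex')
+        # a larger query_ef improves recall at the cost of query latency
+        pipeline = TasB.dot().query_encoder() >> index.voyager_retriever(query_ef=100)
+        pipeline.search('chemical reactions')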
+    """
     pyterrier_dr.util.assert_voyager()
     import voyager
     meta, = self.payload(return_dvecs=False, return_docnos=False)
diff --git a/pyterrier_dr/hgf_models.py b/pyterrier_dr/hgf_models.py
index 280e80e..e7d04f2 100644
--- a/pyterrier_dr/hgf_models.py
+++ b/pyterrier_dr/hgf_models.py
@@ -8,7 +8,7 @@
 
 class HgfBiEncoder(BiEncoder):
     def __init__(self, model, tokenizer, config, batch_size=32, text_field='text', verbose=False, device=None):
-        super().__init__(batch_size, text_field, verbose)
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = torch.device(device)
@@ -56,11 +56,11 @@ def __repr__(self):
 
 class _HgfBiEncoder(HgfBiEncoder, metaclass=Variants):
     VARIANTS: dict = None
 
-    def __init__(self, model_name, batch_size=32, text_field='text', verbose=False, device=None):
-        self.model_name = model_name
-        model = AutoModel.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        config = AutoConfig.from_pretrained(model_name)
+    def __init__(self, model_name=None, batch_size=32, text_field='text', verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
+        model = AutoModel.from_pretrained(self.model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        config = AutoConfig.from_pretrained(self.model_name)
         super().__init__(model, tokenizer, config, batch_size=batch_size, text_field=text_field, verbose=verbose, device=device)
 
     def __repr__(self):
@@ -71,15 +71,31 @@ def __repr__(self):
 
 
 class TasB(_HgfBiEncoder):
+    """Dense encoder for TAS-B (Topic Aware Sampling, Balanced).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/sigir/HofstatterLYLH21
+
+    .. automethod:: dot()
+    """
     VARIANTS = {
         'dot': 'sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco',
     }
 
 
 class RetroMAE(_HgfBiEncoder):
+    """Dense encoder for RetroMAE (Masked Auto-Encoder).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/emnlp/XiaoLSC22
+
+    .. automethod:: msmarco_finetune()
+    .. automethod:: msmarco_distill()
+    .. automethod:: wiki_bookscorpus_beir()
+    """
     VARIANTS = {
-        #'wiki_bookscorpus': 'Shitao/RetroMAE', # only pre-trained
-        #'msmarco': 'Shitao/RetroMAE_MSMARCO', # only pre-trained
         'msmarco_finetune': 'Shitao/RetroMAE_MSMARCO_finetune',
         'msmarco_distill': 'Shitao/RetroMAE_MSMARCO_distill',
         'wiki_bookscorpus_beir': 'Shitao/RetroMAE_BEIR',
diff --git a/pyterrier_dr/prf.py b/pyterrier_dr/prf.py
index 95aa34e..00fa87e 100644
--- a/pyterrier_dr/prf.py
+++ b/pyterrier_dr/prf.py
@@ -22,28 +22,7 @@ class VectorPrf(pt.Transformer):
 
         prf_pipe = model >> index >> index.vec_loader() >> pyterrier_dr.vector_prf() >> index
 
-    .. code-block:: bibtex
-        :caption: Citation
-
-        @article{DBLP:journals/tois/0009MZKZ23,
-          author = {Hang Li and
-                    Ahmed Mourad and
-                    Shengyao Zhuang and
-                    Bevan Koopman and
-                    Guido Zuccon},
-          title = {Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers:
-                   Successes and Pitfalls},
-          journal = {{ACM} Trans. Inf. Syst.},
-          volume = {41},
-          number = {3},
-          pages = {62:1--62:40},
-          year = {2023},
-          url = {https://doi.org/10.1145/3570724},
-          doi = {10.1145/3570724},
-          timestamp = {Fri, 21 Jul 2023 22:26:51 +0200},
-          biburl = {https://dblp.org/rec/journals/tois/0009MZKZ23.bib},
-          bibsource = {dblp computer science bibliography, https://dblp.org}
-        }
+    .. cite.dblp:: journals/tois/0009MZKZ23
     """
 
     def __init__(self, *,
@@ -89,28 +68,7 @@ class AveragePrf(pt.Transformer):
 
         prf_pipe = model >> index >> index.vec_loader() >> pyterrier_dr.average_prf() >> index
 
-    .. code-block:: bibtex
-        :caption: Citation
-
-        @article{DBLP:journals/tois/0009MZKZ23,
-          author = {Hang Li and
-                    Ahmed Mourad and
-                    Shengyao Zhuang and
-                    Bevan Koopman and
-                    Guido Zuccon},
-          title = {Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers:
-                   Successes and Pitfalls},
-          journal = {{ACM} Trans. Inf. Syst.},
-          volume = {41},
-          number = {3},
-          pages = {62:1--62:40},
-          year = {2023},
-          url = {https://doi.org/10.1145/3570724},
-          doi = {10.1145/3570724},
-          timestamp = {Fri, 21 Jul 2023 22:26:51 +0200},
-          biburl = {https://dblp.org/rec/journals/tois/0009MZKZ23.bib},
-          bibsource = {dblp computer science bibliography, https://dblp.org}
-        }
+    .. cite.dblp:: journals/tois/0009MZKZ23
     """
 
     def __init__(self, *,
diff --git a/pyterrier_dr/pt_docs/encoding.rst b/pyterrier_dr/pt_docs/encoding.rst
new file mode 100644
index 0000000..d45113e
--- /dev/null
+++ b/pyterrier_dr/pt_docs/encoding.rst
@@ -0,0 +1,31 @@
+Encoding
+==================================================================
+
+Sentence Transformers
+------------------------------------------------------------------
+
+With pyterrier_dr, it's easy to support Sentence Transformer (formerly called SentenceBERT)
+models, e.g. from HuggingFace, for dense retrieval.
+
+The base class is ``SBertBiEncoder('huggingface/path')``.
+
+Pretrained Encoders
+------------------------------------------------------------------
+
+These classes are convenience aliases to popular dense encoding models.
+
+.. autoclass:: pyterrier_dr.Ance()
+.. autoclass:: pyterrier_dr.BGEM3()
+.. autoclass:: pyterrier_dr.CDE()
+.. autoclass:: pyterrier_dr.E5()
+.. autoclass:: pyterrier_dr.GTR()
+.. autoclass:: pyterrier_dr.Query2Query()
+.. autoclass:: pyterrier_dr.RetroMAE()
+.. autoclass:: pyterrier_dr.TasB()
+.. autoclass:: pyterrier_dr.TctColBert()
+
+API Documentation
+------------------------------------------------------------------
+
+.. autoclass:: pyterrier_dr.BiEncoder
+   :members:
diff --git a/pyterrier_dr/pt_docs/index.rst b/pyterrier_dr/pt_docs/index.rst
index 92328c2..7e8dd8a 100644
--- a/pyterrier_dr/pt_docs/index.rst
+++ b/pyterrier_dr/pt_docs/index.rst
@@ -1,45 +1,21 @@
 Dense Retrieval for PyTerrier
 =======================================================
 
-Features to support Dense Retrieval in `PyTerrier `__.
+`pyterrier-dr `__ is a PyTerrier plugin
+that provides functionality for Dense Retrieval.
 
-.. rubric:: Getting Started
+It provides this functionality primarily through:
 
-.. code-block:: console
-    :caption: Install ``pyterrier-dr`` with ``pip``
+1. Transformers for :doc:`encoding queries/documents <./encoding>` into dense vectors (e.g., :class:`~pyterrier_dr.SBertBiEncoder`)
 
-    $ pip install pyterrier-dr
+2. Transformers for :doc:`indexing and retrieval <./indexing-retrieval>` using these dense vectors (e.g., :class:`~pyterrier_dr.FlexIndex`)
 
-Import ``pyterrier_dr``, load a pre-built index and model, and retrieve:
-
-.. code-block:: python
-    :caption: Basic example of using ``pyterrier_dr``
-
-    >>> from pyterrier_dr import FlexIndex, TasB
-    >>> index = FlexIndex.from_hf('macavaney/vaswani.tasb.flex')
-    >>> model = TasB('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco')
-    >>> pipeline = model.query_encoder() >> index.np_retriever()
-    >>> pipeline.search('chemical reactions')
-
-             score  docno  docid  rank qid               query
-    0    95.841721   7049   7048     0   1  chemical reactions
-    1    94.669395   9374   9373     1   1  chemical reactions
-    2    93.520027   3101   3100     2   1  chemical reactions
-    3    92.809227   6480   6479     3   1  chemical reactions
-    4    92.376190   3452   3451     4   1  chemical reactions
-    ..         ...    ...    ...   ...  ..                 ...
-    995  82.554390   7701   7700   995   1  chemical reactions
-    996  82.552139   1553   1552   996   1  chemical reactions
-    997  82.551933  10064  10063   997   1  chemical reactions
-    998  82.546890   4417   4416   998   1  chemical reactions
-    999  82.545776   7120   7119   999   1  chemical reactions
-
-
-.. rubric:: Table of Contents
+This functionality is covered in more detail in the following pages:
 
 .. toctree::
    :maxdepth: 1
 
+   overview
+   encoding
+   indexing-retrieval
    prf
-   sbert
diff --git a/pyterrier_dr/pt_docs/indexing-retrieval.rst b/pyterrier_dr/pt_docs/indexing-retrieval.rst
new file mode 100644
index 0000000..a9e29fe
--- /dev/null
+++ b/pyterrier_dr/pt_docs/indexing-retrieval.rst
@@ -0,0 +1,91 @@
+Indexing & Retrieval
+=====================================================
+
+This page covers the indexing and retrieval functionality provided by ``pyterrier_dr``.
+
+:class:`~pyterrier_dr.FlexIndex` provides a flexible way to index and retrieve documents
+using dense vectors, and is the main class for indexing and retrieval.
+
+API Documentation
+-----------------------------------------------------
+
+.. autoclass:: pyterrier_dr.FlexIndex
+   :show-inheritance:
+
+   Indexing
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   Basic indexing functionality is provided through :meth:`index`. For more advanced options, use :meth:`indexer`.
+
+   .. automethod:: index
+   .. automethod:: indexer
+
+   Retrieval
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   ``FlexIndex`` provides a variety of retriever backends. Each one expects ``qid`` and ``query_vec`` columns
+   as input, and outputs a result frame. If you do not have a preference for a particular backend, you can use
+   :meth:`retriever` (an alias to :meth:`np_retriever`), which performs exact retrieval using a brute force search
+   over all vectors.
+
+   .. py:method:: retriever(*, num_results=1000)
+
+      Returns a transformer that performs basic exact retrieval over indexed vectors using a brute force search. An alias to :meth:`np_retriever`.
+
+   .. automethod:: np_retriever
+   .. automethod:: torch_retriever
+   .. automethod:: faiss_flat_retriever
+   .. automethod:: faiss_hnsw_retriever
+   .. automethod:: faiss_ivf_retriever
+   .. automethod:: scann_retriever
+   .. automethod:: voyager_retriever
+
+   Re-Ranking
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   Results can be re-ranked using the indexed vectors via :meth:`scorer`. (:meth:`np_scorer` and :meth:`torch_scorer` are
+   available as specific implementations, if needed.)
+
+   :meth:`gar`, :meth:`ladr_proactive`, and :meth:`ladr_adaptive` are *adaptive* re-ranking approaches that pull in other
+   documents from the corpus that may be relevant, as shown in the sketch after this list.
+
+   .. py:method:: scorer
+
+      An alias to :meth:`np_scorer`.
+
+   .. automethod:: np_scorer
+   .. automethod:: torch_scorer
+   .. automethod:: gar
+   .. automethod:: ladr_proactive
+   .. automethod:: ladr_adaptive
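+
+   For example, a brief sketch of a re-ranking pipeline (``bm25`` is a stand-in for any
+   first-stage retriever, and ``model`` for any query encoder):
+
+   .. code-block:: python
+
+      pipeline = bm25 >> model.query_encoder() >> index.scorer()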
+
+   Index Data Access
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   These methods are for low-level index data access.
+
+   .. automethod:: vec_loader
+   .. automethod:: get_corpus_iter
+   .. automethod:: np_vecs
+   .. automethod:: torch_vecs
+   .. automethod:: docnos
+   .. automethod:: corpus_graph
+   .. automethod:: faiss_hnsw_graph
+
+   Extras
+   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+   .. automethod:: built
+
+   .. py:classmethod:: from_hf(repo)
+
+      Loads the index from HuggingFace Hub.
+
+      :param repo: The repository name to download from.
+
+      :returns: A :class:`~pyterrier_dr.FlexIndex` object.
+
+   .. py:method:: to_hf(repo)
+
+      Uploads the index to HuggingFace Hub.
+
+      :param repo: The repository name to upload to.
diff --git a/pyterrier_dr/pt_docs/overview.rst b/pyterrier_dr/pt_docs/overview.rst
new file mode 100644
index 0000000..4cf34a5
--- /dev/null
+++ b/pyterrier_dr/pt_docs/overview.rst
@@ -0,0 +1,138 @@
+Overview
+=======================================================
+
+Installation
+-------------------------------------------------------
+
+``pyterrier-dr`` can be installed with ``pip``.
+
+.. code-block:: console
+    :caption: Install ``pyterrier-dr`` with ``pip``
+
+    $ pip install pyterrier-dr
+
+.. hint::
+
+    Some functionality requires the installation of other software packages. For instance, to retrieve using
+    `FAISS `__ (e.g., using :meth:`~pyterrier_dr.FlexIndex.faiss_hnsw_retriever`),
+    you will need to install the FAISS package:
+
+    .. code-block:: bash
+        :caption: Install FAISS with ``pip`` or ``conda``
+
+        pip install faiss-cpu
+        # or with conda:
+        conda install -c pytorch faiss-cpu
+        # or with GPU support:
+        conda install -c pytorch faiss-gpu
+
+Basic Usage
+-------------------------------------------------------
+
+Dense Retrieval consists of two main components: (1) a model that encodes content as dense vectors,
+and (2) algorithms and data structures to index and retrieve documents using these dense vectors.
+
+Encoding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+(More information can be found at :doc:`encoding`.)
+
+Let's start by loading a dense model: `RetroMAE `__. The model has several
+checkpoints available on huggingface, including ``Shitao/RetroMAE_MSMARCO_distill``.
+``pyterrier_dr`` provides an alias to this checkpoint with :meth:`RetroMAE.msmarco_distill() `:[#]_
+
+.. code-block:: python
+    :caption: Loading a dense model with ``pyterrier_dr``
+
+    >>> from pyterrier_dr import RetroMAE
+    >>> model = RetroMAE.msmarco_distill()
+
+Dense models act as transformers that can encode queries and documents into dense vectors. For example:
+
+.. code-block:: python
+    :caption: Encode queries and documents with a dense model
+
+    >>> import pandas as pd
+    >>> model(pd.DataFrame([
+    ...     {"qid": "0", "query": "hello terrier"},
+    ...     {"qid": "1", "query": "information retrieval"},
+    ...     {"qid": "2", "query": "chemical reactions"},
+    ... ]))
+    qid                  query                          query_vec
+      0          hello terrier  [ 0.26, -0.17,  0.49, -0.12, ...]
+      1  information retrieval  [-0.49,  0.16,  0.24,  0.38, ...]
+      2     chemical reactions  [ 0.19,  0.11, -0.08, -0.00, ...]
+
+    >>> model(pd.DataFrame([
+    ...     {"docno": "1161848_2", "text": "Cutest breed of dog is a PBGV (look up on Internet) they are a little hound that looks like a shaggy terrier."},
+    ...     {"docno": "686980_0", "text": "Golden retriever has longer hair and is a little heavier."},
+    ...     {"docno": "4189224_1", "text": "The onion releases a chemical that makes your eyes water up. I mean, no way short of wearing a mask or just avoiding the sting."},
{"docno": "4189224_1", "text": "The onion releases a chemical that makes your eyes water up. I mean, no way short of wearing a mask or just avoiding the sting."}, + ... ])) + docno text doc_vec + 1161848_2 Cutest breed of dog is a PBGV... [0.03, -0.17, 0.18, -0.03, ...] + 686980_0 Golden retriever has longer h... [0.14, -0.20, 0.00, 0.34, ...] + 4189224_1 The onion releases a chemical... [0.16, 0.03, 0.49, -0.41, ...] + +``query_vec`` and ``doc_vec`` are dense vectors that represent the query and document, respectively. In the +next section, we will use these vectors to perform retrieval. + +Indexing and Retrieval +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +(More information can be found at :doc:`indexing-retrieval`.) + +:class:`pyterrier_dr.FlexIndex` provides dense indexing and retrieval capabilities. Here's how you can index +a collection of documents: + +.. code-block:: python + :caption: Indexing documents with ``pyterrier_dr`` + + >>> from pyterrier_dr import FlexIndex, RetroMAE + >>> model = RetroMAE.msmarco_distill() + >>> index = FlexIndex('my-index.flex') + # build an indexing pipeline that first applies RetroMAE to get dense vectors, then indexes them into the FlexIndex + >>> pipeline = model >> index.indexer() + # run the indexing pipeline over a set of documents + >>> pipeline.index([ + ... {"docno": "1161848_2", "text": "Cutest breed of dog is a PBGV (look up on Internet) they are a little hound that looks like a shaggy terrier."}, + ... {"docno": "686980_0", "text": "Golden retriever has longer hair and is a little heavier."}, + ... {"docno": "4189224_1", "text": "The onion releases a chemical that makes your eyes water up. I mean, no way short of wearing a mask or just avoiding the sting."}, + ... ]) + +Now that the documents are indexed, you can retrieve over them: + +.. code-block:: python + :caption: Retrieving with ``pyterrier_dr`` + + >>> from pyterrier_dr import FlexIndex, RetroMAE + >>> model = RetroMAE.msmarco_distill() + >>> index = FlexIndex('my-index.flex') + # build a retrieval pipeline that first applies RetroMAE to encode the query, then retrieves using those vectors over the FlexIndex + >>> pipeline = model >> index.retriever() + # run the indexing pipeline over a set of documents + >>> pipeline.search('golden retrievers') + qid query docno docid score rank + 0 1 golden retrievers 686980_0 1 77.125557 0 + 1 1 golden retrievers 1161848_2 0 61.379417 1 + 2 1 golden retrievers 4189224_1 2 54.269958 2 + +Extras +------------------------------------------------------- + +#. You can load models from the wonderful `Sentence Transformers `__ library directly + using :class:`~pyterrier_dr.SBertBiEncoder`. + +#. Dense indexing is the most common way to use dense models. But you can also score + any pair of text using a dense model using :meth:`BiEncoder.text_scorer() `. + +#. Re-ranking can often yield better trade-offs between effectiveness and efficiency than doing dense retrieval. + You can build a re-ranking pipeline with :meth:`FlexIndex.scorer() `. + +#. Dense Pseudo-Relevance Feedback (PRF) is a technique to improve the performance of a retrieval system by expanding + the original query vector with the vectors from the top-ranked documents. Check out more :doc:`here `. + +------------------------------------------------------- + +.. [#] You can also load the model from HuggingFace with :class:`~pyterrier_dr.HgfBiEncoder`: + ``HgfBiEncoder("Shitao/RetroMAE_MSMARCO_distill")``. 
+
+-------------------------------------------------------
+
+.. [#] You can also load the model from HuggingFace with :class:`~pyterrier_dr.HgfBiEncoder`:
+   ``HgfBiEncoder("Shitao/RetroMAE_MSMARCO_distill")``. Using the alias will ensure that all settings for
+   the model are assigned properly.
diff --git a/pyterrier_dr/pt_docs/prf.rst b/pyterrier_dr/pt_docs/prf.rst
index 78d9061..0293ee1 100644
--- a/pyterrier_dr/pt_docs/prf.rst
+++ b/pyterrier_dr/pt_docs/prf.rst
@@ -1,10 +1,10 @@
-Pseudo Relevance Feedback (PRF)
+Pseudo-Relevance Feedback
 ===============================
 
 Dense Pseudo Relevance Feedback (PRF) is a technique to improve the performance of a retrieval system by
 expanding the original query vector with the vectors from the top-ranked documents. The idea is that the
 top-ranked documents are likely relevant, so their vectors can help refine the query representation.
 
-PyTerrier-DR provides two dense PRF implementations: :class:`pyterrier_dr.AveragePrf` and :class:`pyterrier_dr.VectorPrf`.
+PyTerrier-DR provides two dense PRF implementations: :class:`~pyterrier_dr.AveragePrf` and :class:`~pyterrier_dr.VectorPrf`.
 
 API Documentation
 -----------------
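+Putting the pieces together, a typical dense PRF pipeline retrieves once, updates ``query_vec`` from the
+top-ranked document vectors, and retrieves again. The following is a minimal sketch (it assumes an existing
+``FlexIndex`` and bi-encoder, and uses ``vec_loader`` to attach the ``doc_vec`` column consumed by the PRF
+transformer):
+
+.. code-block:: python
+
+   from pyterrier_dr import FlexIndex, TctColBert, AveragePrf
+
+   model = TctColBert.base()
+   index = FlexIndex('my-index.flex')
+   pipeline = (model.query_encoder()
+               >> index.retriever()    # first pass: gather feedback documents
+               >> index.vec_loader()   # load doc_vec for the feedback documents
+               >> AveragePrf()         # fold doc vectors into query_vec
+               >> index.retriever())   # second pass with the updated query vector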
diff --git a/pyterrier_dr/pt_docs/sbert.rst b/pyterrier_dr/pt_docs/sbert.rst
deleted file mode 100644
index fb32ea7..0000000
--- a/pyterrier_dr/pt_docs/sbert.rst
+++ /dev/null
@@ -1,24 +0,0 @@
-Using Sentence Transformer models for Dense retrieval in PyTerrier
-==================================================================
-
-With PyTerrier_DR, its easy to support Sentence Transformer (formerly called SentenceBERT)
-models, e.g. from HuggingFace, for dense retrieval.
-
-The base class is ``SBertBiEncoder('huggingface/path')``;
-
-There are easy to remember classes for a number of standard models:
- - ANCE - an early single-representation dense retrieval model: ``Ance.firstp()``
 - - GTR - a dense retrieval model based on the T5 pre-trained encoder: ``GTR.base()``
 - - `E5 `: ``E5.base()``
 - - Query2Query - a query similarity model: Query2Query()
-
-The standard pyterrier_dr pipelines can be used:
-
-Indexing::
-    model = pyterrier_dr.GTR.base()
-    index = pyterrier_dr.FlexIndex('gtr.flex')
-    pipe = (model >> index)
-    pipe.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())
-
-Retrieval::
-    pipe.search("chemical reactions")
diff --git a/pyterrier_dr/sbert_models.py b/pyterrier_dr/sbert_models.py
index 00ad8be..711c8b7 100644
--- a/pyterrier_dr/sbert_models.py
+++ b/pyterrier_dr/sbert_models.py
@@ -26,7 +26,7 @@ def _sbert_encode(self, texts, batch_size=None, prompt=None, normalize_embedding
 
 class SBertBiEncoder(BiEncoder):
     def __init__(self, model_name, batch_size=32, text_field='text', verbose=False, device=None):
-        super().__init__(batch_size, text_field, verbose)
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
         self.model_name = model_name
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -55,11 +55,29 @@ def __repr__(self):
 
 class Ance(_SBertBiEncoder):
+    """Dense encoder for ANCE (Approximate nearest neighbor Negative Contrastive Learning).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/iclr/XiongXLTLBAO21
+
+    .. automethod:: firstp()
+    """
     VARIANTS = {
         'firstp': 'sentence-transformers/msmarco-roberta-base-ance-firstp',
     }
 
 class E5(_SBertBiEncoder):
+    """Dense encoder for E5 (EmbEddings from bidirEctional Encoder rEpresentations).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: journals/corr/abs-2212-03533
+
+    .. automethod:: base()
+    .. automethod:: small()
+    .. automethod:: large()
+    """
     encode_queries = partialmethod(_sbert_encode, prompt='query: ', normalize_embeddings=True)
     encode_docs = partialmethod(_sbert_encode, prompt='passage: ', normalize_embeddings=True)
@@ -71,6 +89,17 @@ class E5(_SBertBiEncoder):
     }
 
 class GTR(_SBertBiEncoder):
+    """Dense encoder for GTR (Generalizable T5-based dense Retrievers).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: conf/emnlp/Ni0LDAMZLHCY22
+
+    .. automethod:: base()
+    .. automethod:: large()
+    .. automethod:: xl()
+    .. automethod:: xxl()
+    """
     VARIANTS = {
         'base': 'sentence-transformers/gtr-t5-base',
         'large': 'sentence-transformers/gtr-t5-large',
@@ -78,18 +107,31 @@ class GTR(_SBertBiEncoder):
         'xxl': 'sentence-transformers/gtr-t5-xxl',
     }
 
-class Query2Query(pt.Transformer):
-    DEFAULT_MODEL_NAME = 'neeva/query2query'
-    def __init__(self, model_name=DEFAULT_MODEL_NAME, batch_size=32, verbose=False, device=None):
-        self.model_name = model_name
+class Query2Query(pt.Transformer, metaclass=Variants):
+    """Dense query encoder model for query similarity.
+
+    Note that this encoder only provides a :meth:`~pyterrier_dr.BiEncoder.query_encoder` (no document encoder or scorer).
+
+    .. cite:: query2query
+        :citation: Bathwal and Samdani. State-of-the-art Query2Query Similarity. 2022.
+        :link: https://web.archive.org/web/20220923212754/https://neeva.com/blog/state-of-the-art-query2query-similarity
+
+    .. automethod:: base()
+    """
+    def __init__(self, model_name=None, batch_size=32, verbose=False, device=None):
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = torch.device(device)
         from sentence_transformers import SentenceTransformer
-        self.model = SentenceTransformer(model_name).to(self.device).eval()
+        self.model = SentenceTransformer(self.model_name).to(self.device).eval()
         self.batch_size = batch_size
         self.verbose = verbose
 
+    VARIANTS = {
+        'base': 'neeva/query2query',
+    }
+
     encode = _sbert_encode
     transform = BiQueryEncoder.transform
     __repr__ = _SBertBiEncoder.__repr__
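As the ``partialmethod`` definitions above show, E5 prepends the ``query: `` and ``passage: `` prompts
automatically, so callers pass raw text. A minimal sketch of this behavior (the input strings are arbitrary):

.. code-block:: python

   from pyterrier_dr import E5

   model = E5.base()
   # encoded internally as "query: chemical reactions"
   q_vecs = model.encode_queries(["chemical reactions"])
   # encoded internally as "passage: Salt dissolves in water."
   d_vecs = model.encode_docs(["Salt dissolves in water."])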
diff --git a/pyterrier_dr/tctcolbert_model.py b/pyterrier_dr/tctcolbert_model.py
index e7f7a11..c202cb7 100644
--- a/pyterrier_dr/tctcolbert_model.py
+++ b/pyterrier_dr/tctcolbert_model.py
@@ -2,18 +2,29 @@
 import numpy as np
 import torch
 from transformers import AutoTokenizer, AutoModel
+from pyterrier_dr.util import Variants
 from . import BiEncoder
 
 
-class TctColBert(BiEncoder):
-    def __init__(self, model_name='castorini/tct_colbert-msmarco', batch_size=32, text_field='text', verbose=False, device=None):
-        super().__init__(batch_size, text_field, verbose)
-        self.model_name = model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+class TctColBert(BiEncoder, metaclass=Variants):
+    """Dense encoder for TCT-ColBERT (Tightly-Coupled Teachers over ColBERT).
+
+    See :class:`~pyterrier_dr.BiEncoder` for usage information.
+
+    .. cite.dblp:: journals/corr/abs-2010-11386
+
+    .. automethod:: base()
+    .. automethod:: hn()
+    .. automethod:: hnp()
+    """
+    def __init__(self, model_name=None, batch_size=32, text_field='text', verbose=False, device=None):
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
+        self.model_name = model_name or next(iter(self.VARIANTS.values()))
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = torch.device(device)
-        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()
+        self.model = AutoModel.from_pretrained(self.model_name).to(self.device).eval()
 
     def encode_queries(self, texts, batch_size=None):
         results = []
@@ -47,3 +58,9 @@ def encode_docs(self, texts, batch_size=None):
 
     def __repr__(self):
         return f'TctColBert({repr(self.model_name)})'
+
+    VARIANTS = {
+        'base': 'castorini/tct_colbert-msmarco',
+        'hn': 'castorini/tct_colbert-v2-hn-msmarco',
+        'hnp': 'castorini/tct_colbert-v2-hnp-msmarco',
+    }
diff --git a/pyterrier_dr/util.py b/pyterrier_dr/util.py
index fba631b..e4e1635 100644
--- a/pyterrier_dr/util.py
+++ b/pyterrier_dr/util.py
@@ -10,8 +10,13 @@ class SimFn(Enum):
 
 class Variants(type):
     def __getattr__(cls, name):
         if name in cls.VARIANTS:
+            @staticmethod
             def wrapped(*args, **kwargs):
                 return cls(cls.VARIANTS[name], *args, **kwargs)
+            wrapped = wrapped.__get__(cls)
+            wrapped.__doc__ = f"Model: ``{cls.VARIANTS[name]}`` `[link] <https://huggingface.co/{cls.VARIANTS[name]}>`__"
+            if name == next(iter(cls.VARIANTS)):
+                wrapped.__doc__ = '*(default)* ' + wrapped.__doc__
             return wrapped
 
     def __init__(self, *args, **kwargs):
diff --git a/tests/test_models.py b/tests/test_models.py
index ff88b94..ce52426 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -11,13 +11,13 @@ class TestModels(unittest.TestCase):
 
     def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test_scorer=True, test_indexer=True, test_retriever=True):
         dataset = pt.get_dataset('irds:vaswani')
+        topics = dataset.get_topics().head(10)
 
-        docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
+        docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 50))
         docs_df = pd.DataFrame(docs)
 
         if test_query_encoder:
             with self.subTest('query_encoder'):
-                topics = dataset.get_topics()
                 enc_topics = model(topics)
                 self.assertEqual(len(enc_topics), len(topics))
                 self.assertTrue('query_vec' in enc_topics.columns)
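The ``Variants`` metaclass change above is what exposes each entry of a model's ``VARIANTS`` mapping as a
documented constructor on the class. A small sketch of the resulting behavior (checkpoint names taken from
``TctColBert.VARIANTS``):

.. code-block:: python

   from pyterrier_dr import TctColBert

   # these two construct the same model:
   model_a = TctColBert.hnp()
   model_b = TctColBert('castorini/tct_colbert-v2-hnp-msmarco')

   # omitting the name falls back to the first (default) variant:
   model_c = TctColBert()  # castorini/tct_colbert-msmarco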