Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

customisable model #48

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pyterrier_colbert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,53 @@
import torch
from colbert.utils.utils import print_message
from collections import OrderedDict, defaultdict
from colbert.modeling.colbert import ColBERT

DEFAULT_CLASS=ColBERT
DEFAULT_MODEL='bert-base-uncased'

def load_model(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL):
    """Build a model from ``baseclass`` and load checkpoint weights into it.

    Args:
        args: Settings object; this function reads ``query_maxlen``,
            ``doc_maxlen``, ``dim``, ``similarity``, ``mask_punctuation``
            and ``checkpoint`` from it.
        do_print: Forwarded to ``print_message`` (as ``condition``) and to
            ``load_checkpoint``.
        baseclass: Class whose ``from_pretrained`` constructs the model.
        basemodel: First positional argument given to ``from_pretrained``.

    Returns:
        A ``(colbert, checkpoint)`` tuple: the model after ``eval()`` has
        been called on it, and the value returned by ``load_checkpoint``.
    """
    # Function-scope import: DEVICE is looked up each time this is called.
    from colbert.parameters import DEVICE

    build = baseclass.from_pretrained
    model_options = {
        'query_maxlen': args.query_maxlen,
        'doc_maxlen': args.doc_maxlen,
        'dim': args.dim,
        'similarity_metric': args.similarity,
        'mask_punctuation': args.mask_punctuation,
    }
    colbert = build(basemodel, **model_options).to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()
    return colbert, checkpoint

def load_colbert(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL):
    """Load a model via ``load_model`` and report checkpoint/argument mismatches.

    For each of ``query_maxlen``, ``doc_maxlen``, ``dim``, ``similarity`` and
    ``amp``, a warning is issued through ``Run.warn`` when the value stored
    under ``checkpoint['arguments']`` differs from the attribute on ``args``.
    The stored arguments are printed as JSON when ``args.rank < 1``.

    Args:
        args: Settings object passed on to ``load_model``; ``args.rank`` is
            read when the checkpoint contains an ``'arguments'`` entry.
        do_print: Forwarded to ``load_model``; when true, a blank line is
            printed before returning.
        baseclass: Forwarded to ``load_model``.
        basemodel: Forwarded to ``load_model``.

    Returns:
        The ``(colbert, checkpoint)`` tuple produced by ``load_model``.
    """
    from colbert.utils.runs import Run
    import ujson

    colbert, checkpoint = load_model(args, do_print, baseclass=baseclass, basemodel=basemodel)

    # TODO: If the parameters below were not specified on the command line,
    # their *checkpoint* values should be used rather than the training defaults.

    has_arguments = 'arguments' in checkpoint

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        # Only compare settings that both the checkpoint and args provide.
        if not (has_arguments and hasattr(args, k)):
            continue
        if k not in checkpoint['arguments']:
            continue
        if checkpoint['arguments'][k] == getattr(args, k):
            continue
        a, b = checkpoint['arguments'][k], getattr(args, k)
        Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if has_arguments and args.rank < 1:
        print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint


def load_checkpoint(path, model, optimizer=None, do_print=True):
if do_print:
Expand Down
27 changes: 13 additions & 14 deletions pyterrier_colbert/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from . import load_checkpoint
# monkeypatch to use our downloading version
import colbert.evaluation.loaders
colbert.evaluation.loaders.load_checkpoint = load_checkpoint
colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint
from . import load_colbert, DEFAULT_MODEL, DEFAULT_CLASS
from colbert.utils.utils import print_message
import pickle
from colbert.indexing.index_manager import IndexManager
Expand All @@ -44,11 +39,13 @@
DEBUG=False

class CollectionEncoder():
def __init__(self, args, process_idx, num_processes):
def __init__(self, args, process_idx, num_processes, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL):
self.args = args
self.collection = args.collection
self.process_idx = process_idx
self.num_processes = num_processes
self.baseclass = baseclass
self.basemodel = basemodel
self.iterator = self._initialize_iterator()

# Chunksize represents the maximum size of a single .pt file
Expand Down Expand Up @@ -80,7 +77,8 @@ def _saver_thread(self):
self._save_batch(*args)

def _load_model(self):
self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0), baseclass=self.baseclass, basemodel=self.basemodel)
import colbert.parameters
if not colbert.parameters.DEVICE == torch.device("cpu"):
self.colbert = self.colbert.cuda()
self.colbert.eval()
Expand Down Expand Up @@ -229,8 +227,8 @@ class Object(object):

class CollectionEncoder_Generator(CollectionEncoder):

def __init__(self, *args, prepend_title=False):
super().__init__(*args)
def __init__(self, *args, prepend_title=False, **kwargs):
super().__init__(*args, **kwargs)
self.prepend_title = prepend_title

def _initialize_iterator(self):
Expand Down Expand Up @@ -259,7 +257,7 @@ def _preprocess_batch(self, offset, lines):


class ColBERTIndexer(IterDictIndexerBase):
def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True):
def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL):
args = Object()
args.similarity = 'cosine'
args.dim = 128
Expand Down Expand Up @@ -287,12 +285,13 @@ def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=
self.prepend_title = prepend_title
self.num_docs = num_docs
self.gpu = gpu
self.baseclass = baseclass
self.basemodel = basemodel
if not gpu:
warn("Gpu disabled, YMMV")
import colbert.parameters
import colbert.evaluation.load_model
import colbert.modeling.colbert
colbert.parameters.DEVICE = colbert.evaluation.load_model.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu")
colbert.parameters.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu")

assert self.args.slices >= 1
assert self.args.sample is None or (0.0 < self.args.sample <1.0), self.args.sample
Expand Down Expand Up @@ -325,7 +324,7 @@ def convert_gen(iterator):
docid+=1
yield l
self.args.generator = convert_gen(iterator)
ceg = CollectionEncoderIds(self.args,0,1) if self.ids else CollectionEncoder_Generator(self.args,0,1)
ceg = CollectionEncoderIds(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel) if self.ids else CollectionEncoder_Generator(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel)

create_directory(self.args.index_root)
create_directory(self.args.index_path)
Expand Down
24 changes: 12 additions & 12 deletions pyterrier_colbert/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from pyterrier.transformer import TransformerBase
from pyterrier.datasets import Dataset
from typing import Union, Tuple
from colbert.evaluation.load_model import load_model
from . import load_checkpoint
# monkeypatch to use our downloading version
import colbert.evaluation.loaders
colbert.evaluation.loaders.load_checkpoint = load_checkpoint
colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint
from . import load_model, DEFAULT_CLASS, DEFAULT_MODEL
from colbert.modeling.inference import ModelInference
from colbert.evaluation.slow import slow_rerank
from colbert.indexing.loaders import get_parts, load_doclens
Expand Down Expand Up @@ -215,7 +210,10 @@ def our_rerank_with_embeddings_batched(self, qembs, pids, weightsQ=None, gpu=Tru
class ColBERTModelOnlyFactory():

def __init__(self,
colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]], gpu=True):
colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]],
gpu=True,
baseclass = DEFAULT_CLASS,
basemodel = DEFAULT_MODEL):
args = Object()
args.query_maxlen = 32
args.doc_maxlen = 180
Expand All @@ -232,19 +230,20 @@ def __init__(self,
if not gpu:

warn("Gpu disabled, YMMV")
import torch
import colbert.parameters
import colbert.evaluation.load_model
import colbert.modeling.colbert
colbert.parameters.DEVICE = colbert.evaluation.load_model.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu")
self.gpu = False
if isinstance (colbert_model, str):
args.checkpoint = colbert_model
args.colbert, args.checkpoint = load_model(args)
args.colbert, args.checkpoint = load_model(args, baseclass=baseclass, basemodel=basemodel)
else:
assert isinstance(colbert_model, tuple)
args.colbert, args.checkpoint = colbert_model
from colbert.modeling.colbert import ColBERT
assert isinstance(args.colbert, ColBERT)
import torch.nn
assert isinstance(args.colbert, torch.nn.Module)
assert isinstance(args.checkpoint, dict)

args.inference = ModelInference(args.colbert, amp=args.amp)
Expand Down Expand Up @@ -500,9 +499,10 @@ def __init__(self,
faiss_partitions=None,#TODO 100-
memtype = "mem",
faisstype= "mem",
gpu=True):
gpu=True,
**kwargs):

super().__init__(colbert_model, gpu=gpu)
super().__init__(colbert_model, gpu=gpu, **kwargs)

self.verbose = False
self._faissnn = None
Expand Down