diff --git a/pyterrier_colbert/__init__.py b/pyterrier_colbert/__init__.py index 961f061..06ebf4c 100644 --- a/pyterrier_colbert/__init__.py +++ b/pyterrier_colbert/__init__.py @@ -7,6 +7,53 @@ import torch from colbert.utils.utils import print_message from collections import OrderedDict, defaultdict +from colbert.modeling.colbert import ColBERT + +DEFAULT_CLASS=ColBERT +DEFAULT_MODEL='bert-base-uncased' + +def load_model(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): + from colbert.parameters import DEVICE + colbert = baseclass.from_pretrained(basemodel, + query_maxlen=args.query_maxlen, + doc_maxlen=args.doc_maxlen, + dim=args.dim, + similarity_metric=args.similarity, + mask_punctuation=args.mask_punctuation) + colbert = colbert.to(DEVICE) + + print_message("#> Loading model checkpoint.", condition=do_print) + + checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print) + + colbert.eval() + + return colbert, checkpoint + +def load_colbert(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): + from colbert.utils.runs import Run + import ujson + + colbert, checkpoint = load_model(args, do_print, baseclass=baseclass, basemodel=basemodel) + + # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used, + # i.e., the values saved in checkpoint['arguments'] rather than the training-time defaults. 
+ + for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']: + if 'arguments' in checkpoint and hasattr(args, k): + if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k): + a, b = checkpoint['arguments'][k], getattr(args, k) + Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})") + + if 'arguments' in checkpoint: + if args.rank < 1: + print(ujson.dumps(checkpoint['arguments'], indent=4)) + + if do_print: + print('\n') + + return colbert, checkpoint + def load_checkpoint(path, model, optimizer=None, do_print=True): if do_print: diff --git a/pyterrier_colbert/indexing.py b/pyterrier_colbert/indexing.py index 5184c3e..ee2cc33 100644 --- a/pyterrier_colbert/indexing.py +++ b/pyterrier_colbert/indexing.py @@ -30,12 +30,7 @@ import queue from colbert.modeling.inference import ModelInference -from colbert.evaluation.loaders import load_colbert -from . import load_checkpoint -# monkeypatch to use our downloading version -import colbert.evaluation.loaders -colbert.evaluation.loaders.load_checkpoint = load_checkpoint -colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint +from . 
import load_colbert, DEFAULT_MODEL, DEFAULT_CLASS from colbert.utils.utils import print_message import pickle from colbert.indexing.index_manager import IndexManager @@ -44,11 +39,13 @@ DEBUG=False class CollectionEncoder(): - def __init__(self, args, process_idx, num_processes): + def __init__(self, args, process_idx, num_processes, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): self.args = args self.collection = args.collection self.process_idx = process_idx self.num_processes = num_processes + self.baseclass = baseclass + self.basemodel = basemodel self.iterator = self._initialize_iterator() # Chunksize represents the maximum size of a single .pt file @@ -80,7 +77,8 @@ def _saver_thread(self): self._save_batch(*args) def _load_model(self): - self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0)) + self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0), baseclass=self.baseclass, basemodel=self.basemodel) + import colbert.parameters if not colbert.parameters.DEVICE == torch.device("cpu"): self.colbert = self.colbert.cuda() self.colbert.eval() @@ -229,8 +227,8 @@ class Object(object): class CollectionEncoder_Generator(CollectionEncoder): - def __init__(self, *args, prepend_title=False): - super().__init__(*args) + def __init__(self, *args, prepend_title=False, **kwargs): + super().__init__(*args, **kwargs) self.prepend_title = prepend_title def _initialize_iterator(self): @@ -259,7 +257,7 @@ def _preprocess_batch(self, offset, lines): class ColBERTIndexer(IterDictIndexerBase): - def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True): + def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): args = Object() args.similarity = 'cosine' args.dim = 128 @@ -287,12 +285,13 @@ def __init__(self, checkpoint, index_root, 
index_name, chunksize, prepend_title= self.prepend_title = prepend_title self.num_docs = num_docs self.gpu = gpu + self.baseclass = baseclass + self.basemodel = basemodel if not gpu: warn("Gpu disabled, YMMV") import colbert.parameters - import colbert.evaluation.load_model import colbert.modeling.colbert - colbert.parameters.DEVICE = colbert.evaluation.load_model.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu") + colbert.parameters.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu") assert self.args.slices >= 1 assert self.args.sample is None or (0.0 < self.args.sample <1.0), self.args.sample @@ -325,7 +324,7 @@ def convert_gen(iterator): docid+=1 yield l self.args.generator = convert_gen(iterator) - ceg = CollectionEncoderIds(self.args,0,1) if self.ids else CollectionEncoder_Generator(self.args,0,1) + ceg = CollectionEncoderIds(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel) if self.ids else CollectionEncoder_Generator(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel) create_directory(self.args.index_root) create_directory(self.args.index_path) diff --git a/pyterrier_colbert/ranking.py b/pyterrier_colbert/ranking.py index 2c3e9f5..a14df32 100644 --- a/pyterrier_colbert/ranking.py +++ b/pyterrier_colbert/ranking.py @@ -9,12 +9,7 @@ from pyterrier.transformer import TransformerBase from pyterrier.datasets import Dataset from typing import Union, Tuple -from colbert.evaluation.load_model import load_model -from . import load_checkpoint -# monkeypatch to use our downloading version -import colbert.evaluation.loaders -colbert.evaluation.loaders.load_checkpoint = load_checkpoint -colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint +from . 
import load_model, DEFAULT_CLASS, DEFAULT_MODEL from colbert.modeling.inference import ModelInference from colbert.evaluation.slow import slow_rerank from colbert.indexing.loaders import get_parts, load_doclens @@ -215,7 +210,10 @@ def our_rerank_with_embeddings_batched(self, qembs, pids, weightsQ=None, gpu=Tru class ColBERTModelOnlyFactory(): def __init__(self, - colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]], gpu=True): + colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]], + gpu=True, + baseclass = DEFAULT_CLASS, + basemodel = DEFAULT_MODEL): args = Object() args.query_maxlen = 32 args.doc_maxlen = 180 @@ -232,6 +230,7 @@ def __init__(self, if not gpu: warn("Gpu disabled, YMMV") + import torch import colbert.parameters import colbert.evaluation.load_model import colbert.modeling.colbert @@ -239,12 +238,12 @@ def __init__(self, self.gpu = False if isinstance (colbert_model, str): args.checkpoint = colbert_model - args.colbert, args.checkpoint = load_model(args) + args.colbert, args.checkpoint = load_model(args, baseclass=baseclass, basemodel=basemodel) else: assert isinstance(colbert_model, tuple) args.colbert, args.checkpoint = colbert_model - from colbert.modeling.colbert import ColBERT - assert isinstance(args.colbert, ColBERT) + import torch.nn + assert isinstance(args.colbert, torch.nn.Module) assert isinstance(args.checkpoint, dict) args.inference = ModelInference(args.colbert, amp=args.amp) @@ -500,9 +499,10 @@ def __init__(self, faiss_partitions=None,#TODO 100- memtype = "mem", faisstype= "mem", - gpu=True): + gpu=True, + **kwargs): - super().__init__(colbert_model, gpu=gpu) + super().__init__(colbert_model, gpu=gpu, **kwargs) self.verbose = False self._faissnn = None