From a386d3603e0bc62d126d6105695877eee6d5d47e Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Thu, 25 Aug 2022 18:14:38 +0100 Subject: [PATCH 1/6] customisable model --- pyterrier_colbert/__init__.py | 47 +++++++++++++++++++++++++++++++++++ pyterrier_colbert/indexing.py | 21 ++++++++-------- pyterrier_colbert/ranking.py | 19 +++++++------- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/pyterrier_colbert/__init__.py b/pyterrier_colbert/__init__.py index 961f061..06ebf4c 100644 --- a/pyterrier_colbert/__init__.py +++ b/pyterrier_colbert/__init__.py @@ -7,6 +7,53 @@ import torch from colbert.utils.utils import print_message from collections import OrderedDict, defaultdict +from colbert.modeling.colbert import ColBERT + +DEFAULT_CLASS=ColBERT +DEFAULT_MODEL='bert-base-uncased' + +def load_model(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): + from colbert.parameters import DEVICE + colbert = baseclass.from_pretrained(basemodel, + query_maxlen=args.query_maxlen, + doc_maxlen=args.doc_maxlen, + dim=args.dim, + similarity_metric=args.similarity, + mask_punctuation=args.mask_punctuation) + colbert = colbert.to(DEVICE) + + print_message("#> Loading model checkpoint.", condition=do_print) + + checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print) + + colbert.eval() + + return colbert, checkpoint + +def load_colbert(args, do_print=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): + from colbert.utils.runs import Run + import ujson + + colbert, checkpoint = load_model(args, do_print, baseclass=baseclass, basemodel=basemodel) + + # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used. + # I.e., not their purely (i.e., training) default values. + + for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']: + if 'arguments' in checkpoint and hasattr(args, k): + if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k): + a, b = checkpoint['arguments'][k], getattr(args, k) + Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})") + + if 'arguments' in checkpoint: + if args.rank < 1: + print(ujson.dumps(checkpoint['arguments'], indent=4)) + + if do_print: + print('\n') + + return colbert, checkpoint + def load_checkpoint(path, model, optimizer=None, do_print=True): if do_print: diff --git a/pyterrier_colbert/indexing.py b/pyterrier_colbert/indexing.py index 5184c3e..773ddf8 100644 --- a/pyterrier_colbert/indexing.py +++ b/pyterrier_colbert/indexing.py @@ -31,11 +31,7 @@ from colbert.modeling.inference import ModelInference from colbert.evaluation.loaders import load_colbert -from . import load_checkpoint -# monkeypatch to use our downloading version -import colbert.evaluation.loaders -colbert.evaluation.loaders.load_checkpoint = load_checkpoint -colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint +from . import load_colbert, DEFAULT_MODEL, DEFAULT_CLASS from colbert.utils.utils import print_message import pickle from colbert.indexing.index_manager import IndexManager @@ -44,11 +40,13 @@ DEBUG=False class CollectionEncoder(): - def __init__(self, args, process_idx, num_processes): + def __init__(self, args, process_idx, num_processes, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): self.args = args self.collection = args.collection self.process_idx = process_idx self.num_processes = num_processes + self.baseclass = baseclass + self.basemodel = basemodel self.iterator = self._initialize_iterator() # Chunksize represents the maximum size of a single .pt file @@ -80,7 +78,7 @@ def _saver_thread(self): self._save_batch(*args) def _load_model(self): - self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0)) + self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0), baseclass=self.baseclass, basemodel=self.basemodel) if not colbert.parameters.DEVICE == torch.device("cpu"): self.colbert = self.colbert.cuda() self.colbert.eval() @@ -259,7 +257,7 @@ def _preprocess_batch(self, offset, lines): class ColBERTIndexer(IterDictIndexerBase): - def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True): + def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True, baseclass=DEFAULT_CLASS, basemodel=DEFAULT_MODEL): args = Object() args.similarity = 'cosine' args.dim = 128 @@ -287,12 +285,13 @@ def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title= self.prepend_title = prepend_title self.num_docs = num_docs self.gpu = gpu + self.baseclass = baseclass + self.basemodel = basemodel if not gpu: warn("Gpu disabled, YMMV") import colbert.parameters - import colbert.evaluation.load_model import colbert.modeling.colbert - colbert.parameters.DEVICE = colbert.evaluation.load_model.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu") + colbert.parameters.DEVICE = colbert.modeling.colbert.DEVICE = torch.device("cpu") assert self.args.slices >= 1 assert self.args.sample is None or (0.0 < self.args.sample <1.0), self.args.sample @@ -325,7 +324,7 @@ def convert_gen(iterator): docid+=1 yield l self.args.generator = convert_gen(iterator) - ceg = CollectionEncoderIds(self.args,0,1) if self.ids else CollectionEncoder_Generator(self.args,0,1) + ceg = CollectionEncoderIds(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel) if self.ids else CollectionEncoder_Generator(self.args,0,1, baseclass=self.baseclass, basemodel=self.basemodel) create_directory(self.args.index_root) create_directory(self.args.index_path) diff --git a/pyterrier_colbert/ranking.py b/pyterrier_colbert/ranking.py index 2c3e9f5..dd6fb47 100644 --- a/pyterrier_colbert/ranking.py +++ b/pyterrier_colbert/ranking.py @@ -9,12 +9,7 @@ from pyterrier.transformer import TransformerBase from pyterrier.datasets import Dataset from typing import Union, Tuple -from colbert.evaluation.load_model import load_model -from . import load_checkpoint -# monkeypatch to use our downloading version -import colbert.evaluation.loaders -colbert.evaluation.loaders.load_checkpoint = load_checkpoint -colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint +from . import load_model, DEFAULT_CLASS, DEFAULT_MODEL from colbert.modeling.inference import ModelInference from colbert.evaluation.slow import slow_rerank from colbert.indexing.loaders import get_parts, load_doclens @@ -215,7 +210,10 @@ def our_rerank_with_embeddings_batched(self, qembs, pids, weightsQ=None, gpu=Tru class ColBERTModelOnlyFactory(): def __init__(self, - colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]], gpu=True): + colbert_model : Union[str, Tuple[colbert.modeling.colbert.ColBERT, dict]], + gpu=True, + baseclass = DEFAULT_CLASS, + basemodel = DEFAULT_MODEL): args = Object() args.query_maxlen = 32 args.doc_maxlen = 180 @@ -239,7 +237,7 @@ def __init__(self, self.gpu = False if isinstance (colbert_model, str): args.checkpoint = colbert_model - args.colbert, args.checkpoint = load_model(args) + args.colbert, args.checkpoint = load_model(args, baseclass=baseclass, basemodel=basemodel) else: assert isinstance(colbert_model, tuple) args.colbert, args.checkpoint = colbert_model @@ -500,9 +498,10 @@ def __init__(self, faiss_partitions=None,#TODO 100- memtype = "mem", faisstype= "mem", - gpu=True): + gpu=True, + **kwargs): - super().__init__(colbert_model, gpu=gpu) + super().__init__(colbert_model, gpu=gpu, **kwargs) self.verbose = False self._faissnn = None From dffe04e6b10efa3d2e2239033a7cbe62273f11a3 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Thu, 25 Aug 2022 18:24:13 +0100 Subject: [PATCH 2/6] pass kwargs through constructor --- pyterrier_colbert/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyterrier_colbert/indexing.py b/pyterrier_colbert/indexing.py index 773ddf8..6376e4c 100644 --- a/pyterrier_colbert/indexing.py +++ b/pyterrier_colbert/indexing.py @@ -227,8 +227,8 @@ class Object(object): class CollectionEncoder_Generator(CollectionEncoder): - def __init__(self, *args, prepend_title=False): - super().__init__(*args) + def __init__(self, *args, prepend_title=False, **kwargs): + super().__init__(*args, **kwargs) self.prepend_title = prepend_title def _initialize_iterator(self): From 6933d54a458d173d27a3e307c5692f2b78e55478 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Thu, 25 Aug 2022 18:34:26 +0100 Subject: [PATCH 3/6] fix import --- pyterrier_colbert/indexing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyterrier_colbert/indexing.py b/pyterrier_colbert/indexing.py index 6376e4c..2cb9d2d 100644 --- a/pyterrier_colbert/indexing.py +++ b/pyterrier_colbert/indexing.py @@ -79,6 +79,7 @@ def _saver_thread(self): def _load_model(self): self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0), baseclass=self.baseclass, basemodel=self.basemodel) + import colbert.parameters if not colbert.parameters.DEVICE == torch.device("cpu"): self.colbert = self.colbert.cuda() self.colbert.eval() From 471d2df2828678f2c9d531311d16f54053dac634 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Thu, 25 Aug 2022 18:34:46 +0100 Subject: [PATCH 4/6] remove dup import --- pyterrier_colbert/indexing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyterrier_colbert/indexing.py b/pyterrier_colbert/indexing.py index 2cb9d2d..ee2cc33 100644 --- a/pyterrier_colbert/indexing.py +++ b/pyterrier_colbert/indexing.py @@ -30,7 +30,6 @@ import queue from colbert.modeling.inference import ModelInference -from colbert.evaluation.loaders import load_colbert from . import load_colbert, DEFAULT_MODEL, DEFAULT_CLASS from colbert.utils.utils import print_message import pickle From 99d69f45505db0799b76b00343015c5da46d5089 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Fri, 26 Aug 2022 14:59:35 +0100 Subject: [PATCH 5/6] loosen assertion --- pyterrier_colbert/ranking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyterrier_colbert/ranking.py b/pyterrier_colbert/ranking.py index dd6fb47..0a0ae62 100644 --- a/pyterrier_colbert/ranking.py +++ b/pyterrier_colbert/ranking.py @@ -241,8 +241,8 @@ def __init__(self, else: assert isinstance(colbert_model, tuple) args.colbert, args.checkpoint = colbert_model - from colbert.modeling.colbert import ColBERT - assert isinstance(args.colbert, ColBERT) + import torch.nn + assert isinstance(args.colbert, torch.nn.Module) assert isinstance(args.checkpoint, dict) args.inference = ModelInference(args.colbert, amp=args.amp) From 2e1b586e713497c40bdfd96c3d301213c87e1a44 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Fri, 26 Aug 2022 15:02:24 +0100 Subject: [PATCH 6/6] added import --- pyterrier_colbert/ranking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyterrier_colbert/ranking.py b/pyterrier_colbert/ranking.py index 0a0ae62..a14df32 100644 --- a/pyterrier_colbert/ranking.py +++ b/pyterrier_colbert/ranking.py @@ -230,6 +230,7 @@ def __init__(self, if not gpu: warn("Gpu disabled, YMMV") + import torch import colbert.parameters import colbert.evaluation.load_model import colbert.modeling.colbert