Indexmgr #16

Open · wants to merge 11 commits into base: main
103 changes: 93 additions & 10 deletions pyterrier_colbert/indexing.py
@@ -41,8 +41,77 @@

DEBUG=False

def get_parts_ext(directory):
    # extension of ColBERT's get_parts that also recognises .npy and .store index files
    extensions = ['.pt', '.npy', '.store']

    parts = []
    extension = None
    for ext in extensions:
        parts = sorted([int(filename[: -len(ext)]) for filename in os.listdir(directory)
                        if filename.endswith(ext)])
        if len(parts) > 0:
            extension = ext
            print("Found %d index files with extension %s" % (len(parts), extension))
            break
    if len(parts) == 0:
        raise ValueError("Found no index embedding files in %s" % directory)

    # Part numbers must be contiguous integers starting at 0 -- integer-sortedness matters.
    assert list(range(len(parts))) == parts, parts

    parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
    samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts]
    return parts, parts_paths, samples_paths

def load_index_part_torch(filename, verbose=True, dim=128):
    # mmap the raw HalfStorage file; the element count is recovered from the
    # file size (2 bytes per half-precision value). dim defaults to ColBERT's 128.
    num_elements = os.path.getsize(filename) // 2
    mmap_storage = torch.HalfStorage.from_file(filename, False, num_elements)
    return torch.HalfTensor(mmap_storage).view(num_elements // dim, dim)

def load_index_part_torchhalf(filename, verbose=True):
    return torch.load(filename)

def load_index_part_numpy(filename):
    import numpy as np
    if filename.endswith(".pt"):
        filename = filename.replace(".pt", ".npy")
        return torch.from_numpy(np.load(filename))
    else:
        # resort to torch for .sample, .ids etc.
        return torch.load(filename)

class TorchStorageIndexManager(IndexManager):
    """
    A ColBERT IndexManager for torch.HalfStorage files, which support mmap
    """

    def save(self, tensor, output_file):
        if not output_file.endswith(".pt"):
            # for .ids, .sample etc., resort to torch.save
            return super().save(tensor, output_file)
        output_file = output_file.replace(".pt", ".store")
        size = tensor.shape[0] * tensor.shape[1]
        out_storage = torch.HalfStorage.from_file(output_file, True, size)
        torch.HalfTensor(out_storage).copy_(tensor.view(-1))

class NumpyIndexManager(IndexManager):
    """
    A ColBERT IndexManager for numpy files, which support both mmap and direct loading
    """
    def save(self, tensor, output_file):
        if not output_file.endswith(".pt"):
            # for .ids, .sample etc., resort to torch.save
            return super().save(tensor, output_file)
        import numpy as np
        output_file = output_file.replace(".pt", ".npy")
        np.save(output_file, tensor.detach().numpy())
        # alternative: write through a numpy memmap instead of np.save
        #memmap = np.memmap(output_file, dtype=np.float16, mode='w+', shape=tensor.shape)
        #memmap[:] = tensor[:]
        #memmap.flush()
        #del memmap

class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
    def __init__(self, args, process_idx, num_processes, indexmgr=None):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
@@ -68,7 +137,19 @@ def __init__(self, args, process_idx, num_processes):
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)

        import colbert.indexing.index_manager, colbert.indexing.loaders, colbert.indexing.faiss
        if indexmgr == 'numpy':
            self.indexmgr = NumpyIndexManager(args.dim)
            colbert.indexing.faiss.load_index_part = load_index_part_numpy
            colbert.indexing.faiss.get_parts = colbert.indexing.loaders.get_parts = get_parts_ext
        elif indexmgr == 'half':
            assert False, "the 'half' index manager is not yet supported"
            self.indexmgr = TorchStorageIndexManager(args.dim)
        else:
            colbert.indexing.faiss.get_parts = colbert.indexing.loaders.get_parts = get_parts_ext
            colbert.indexing.faiss.load_index_part = colbert.indexing.index_manager.load_index_part
            self.indexmgr = IndexManager(args.dim)

    def _initialize_iterator(self):
        return open(self.collection)
@@ -225,8 +306,8 @@ class Object(object):

class CollectionEncoder_Generator(CollectionEncoder):

    def __init__(self, *args, prepend_title=False):
        super().__init__(*args)
    def __init__(self, *args, prepend_title=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.prepend_title = prepend_title

    def _initialize_iterator(self):
@@ -253,7 +334,7 @@ def _preprocess_batch(self, offset, lines):


class ColBERTIndexer(IterDictIndexerBase):
    def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True):
    def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=False, num_docs=None, ids=True, gpu=True, indexmgr=None):
        args = Object()
        args.similarity = 'cosine'
        args.dim = 128
@@ -281,6 +362,7 @@ def __init__(self, checkpoint, index_root, index_name, chunksize, prepend_title=
        self.prepend_title = prepend_title
        self.num_docs = num_docs
        self.gpu = gpu
        self.indexmgr = indexmgr
        if not gpu:
            warn("GPU disabled, YMMV")
            import colbert.parameters
@@ -319,7 +401,7 @@ def convert_gen(iterator):
                docid += 1
                yield l
        self.args.generator = convert_gen(iterator)
        ceg = CollectionEncoderIds(self.args, 0, 1) if self.ids else CollectionEncoder_Generator(self.args, 0, 1)
        ceg = CollectionEncoderIds(self.args, 0, 1, indexmgr=self.indexmgr) if self.ids else CollectionEncoder_Generator(self.args, 0, 1, indexmgr=self.indexmgr)

        create_directory(self.args.index_root)
        create_directory(self.args.index_path)
@@ -449,17 +531,18 @@ def merge_colbert_files(src_dirs, dst_dir):
"""Re-count and sym-link ColBERT index files in src_dirs folders into
a unified ColBERT index in dst_dir folder"""

FILE_PATTERNS = ["%d.pt", "%d.sample", "%d.tokenids", "doclens.%d.json"]
FILE_PATTERNS = ["%d.pt", "%d.store", "%d.np", "%d.sample", "%d.tokenids", "doclens.%d.json"]

src_sizes = [count_parts(d) for d in src_dirs]

offset = 0
for src_size, src_dir in zip(src_sizes, src_dirs):
for i in range(src_size):
for file in FILE_PATTERNS:
src_file = os.path.join(src_dir, file % i)
dst_file = os.path.join(dst_dir, file % (offset + i))
os.symlink(src_file, dst_file)
if os.path.exists(src_file):
src_file = os.path.join(src_dir, file % i)
dst_file = os.path.join(dst_dir, file % (offset + i))
os.symlink(src_file, dst_file)
offset += src_size

def make_new_faiss(index_root, index_name, **kwargs):
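For context, here is a minimal sketch (not part of the diff) of the torch.HalfStorage round-trip that TorchStorageIndexManager and load_index_part_torch rely on. torch.HalfStorage is torch's legacy typed-storage API; the file name and shape below are illustrative.

import torch

# write: back a half-precision storage with a file, then copy the tensor into it
t = torch.randn(4, 128).half()
store = torch.HalfStorage.from_file("example.store", True, t.numel())  # shared=True creates the file
torch.HalfTensor(store).copy_(t.view(-1))

# read: map the same file back without an eager load, then restore the 2D shape
store2 = torch.HalfStorage.from_file("example.store", False, t.numel())
restored = torch.HalfTensor(store2).view(4, 128)
assert torch.equal(t, restored)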
35 changes: 35 additions & 0 deletions pyterrier_colbert/ranking.py
@@ -55,6 +55,37 @@ def get_embedding(self, pid):
        endpos = self.endpos[pid]
        return self.mmap[startpos:endpos, :]

class numpy_file_part_mmap:
    def __init__(self, file_path, file_doclens):
        import numpy as np
        self.dim = 128 # TODO
        file_path = file_path.replace(".pt", ".npy")
        self.doclens = file_doclens
        self.endpos = np.cumsum(self.doclens)
        self.startpos = self.endpos - self.doclens
        # mmap_mode keeps the shard on disk; torch.from_numpy wraps it without copying
        self.mmap = torch.from_numpy(np.load(file_path, mmap_mode='r+'))

    def get_embedding(self, pid):
        startpos = self.startpos[pid]
        endpos = self.endpos[pid]
        return self.mmap[startpos:endpos, :]

class numpy_file_part_mem:
    def __init__(self, file_path, file_doclens):
        import numpy as np
        self.dim = 128 # TODO
        file_path = file_path.replace(".pt", ".npy")
        self.doclens = file_doclens
        self.endpos = np.cumsum(self.doclens)
        self.startpos = self.endpos - self.doclens
        # load the full shard into memory
        self.mmap = torch.from_numpy(np.load(file_path))

    def get_embedding(self, pid):
        startpos = self.startpos[pid]
        endpos = self.endpos[pid]
        return self.mmap[startpos:endpos, :]

class Object(object):
    pass
@@ -101,6 +132,10 @@ def _load_parts(index_path, part_doclens, memtype="mmap"):
        mmaps = [file_part_mmap(path, doclens) for path, doclens in zip(all_parts_paths, part_doclens)]
    elif memtype == "mem":
        mmaps = [file_part_mem(path, doclens) for path, doclens in tqdm(zip(all_parts_paths, part_doclens), total=len(all_parts_paths), desc="Loading index shards to memory", unit="shard")]
    elif memtype == "numpy":
        mmaps = [numpy_file_part_mem(path, doclens) for path, doclens in tqdm(zip(all_parts_paths, part_doclens), total=len(all_parts_paths), desc="Loading index shards to memory", unit="shard")]
    elif memtype == "numpy_mmap":
        mmaps = [numpy_file_part_mmap(path, doclens) for path, doclens in tqdm(zip(all_parts_paths, part_doclens), total=len(all_parts_paths), desc="Opening memory-mapped index shards", unit="shard")]
    else:
        assert False, "Unknown memtype %s" % memtype
    return mmaps
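For context, a minimal sketch (not part of the diff) of the numpy path: NumpyIndexManager writes each shard with np.save, and numpy_file_part_mmap reopens it with mmap_mode so that get_embedding only pages in the slices it touches. The shard name and doclens below are illustrative.

import numpy as np
import torch

emb = np.random.rand(10, 128).astype(np.float16)
np.save("0.npy", emb)                                     # what NumpyIndexManager.save does

mm = torch.from_numpy(np.load("0.npy", mmap_mode='r+'))   # reads the header only; data stays on disk
doclens = [3, 4, 3]
endpos = np.cumsum(doclens)
startpos = endpos - doclens
passage = mm[startpos[1]:endpos[1], :]                    # a 4 x 128 view, paged in lazily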
24 changes: 14 additions & 10 deletions tests/test_indexing.py
@@ -4,7 +4,7 @@
CHECKPOINT="http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip"
class TestIndexing(unittest.TestCase):

    def _indexing_1doc(self, indexmgr):
    def _indexing_1doc(self, indexmgr, indexread):
        # minimum test case size is 100 docs, 40 WordPiece tokens, and nx > k; we found 200 worked
        import pyterrier as pt
        from pyterrier_colbert.indexing import ColBERTIndexer
@@ -13,7 +13,7 @@ def _indexing_1doc(self, indexmgr):
            CHECKPOINT,
            os.path.dirname(self.test_dir), os.path.basename(self.test_dir),
            chunksize=3,
            #indexmgr=indexmgr,
            indexmgr=indexmgr,
            gpu=False)

        iter = pt.get_dataset("vaswani").get_corpus_iter()
@@ -22,6 +22,7 @@
        import pyterrier_colbert.pruning as pruning

        for factory in [indexer.ranking_factory()]:
            factory.memtype = indexread

            for pipe, has_score, name in [
                (factory.end_to_end(), True, "E2E"),
@@ -58,12 +59,6 @@ def _indexing_1doc(self, indexmgr):
            else:
                self.assertFalse("score" in dfOut.columns)

    # def test_indexing_1doc_numpy(self):
    #     self._indexing_1doc('numpy')

    # def test_indexing_1doc_half(self):
    #     self._indexing_1doc('half')

    def indexing_empty(self):
        # minimum test case size is 100 docs, 40 WordPiece tokens, and nx > k; we found 200 worked
        import pyterrier as pt
@@ -108,8 +103,17 @@ def indexing_merged(self):
        factory = ColBERTFactory(CHECKPOINT, index_root, "index_part", faiss_partitions=100, gpu=False)
        self.assertEqual(400, len(factory.docid2docno))

    def test_indexing_1doc_torch(self):
        self._indexing_1doc('torch')
    def test_indexing_1doc_torch_mem(self):
        self._indexing_1doc('torch', "mem")

    # def test_indexing_1doc_half_mmap(self):
    #     self._indexing_1doc('half', "mmap")

    def test_indexing_1doc_numpy_mem(self):
        self._indexing_1doc('numpy', 'numpy')

    def test_indexing_1doc_numpy_mmap(self):
        self._indexing_1doc('numpy', 'numpy_mmap')

    def setUp(self):
        import pyterrier as pt
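Taken together, indexmgr selects the on-disk format at indexing time and the factory's memtype selects how shards are opened at retrieval time, as the tests above exercise. A usage sketch under those assumptions (index paths and the query are illustrative; assumes pt.init() has been called):

import pyterrier as pt
from pyterrier_colbert.indexing import ColBERTIndexer

indexer = ColBERTIndexer(CHECKPOINT, "/tmp/colbert_indices", "numpy_index",
                         chunksize=3, indexmgr='numpy', gpu=False)
indexer.index(pt.get_dataset("vaswani").get_corpus_iter())

factory = indexer.ranking_factory()
factory.memtype = 'numpy_mmap'   # or 'numpy' to load the shards fully into memory
res = factory.end_to_end().search("chemical reactions")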