Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] #107

Open
wants to merge 4 commits into
base: index_on_ssd
Choose a base branch
from
Open

[WIP] #107

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colbert/indexing/codecs/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ResidualCodec:
Embeddings = ResidualEmbeddings

def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None):
self.mmap_index = config.mmap_index
self.use_gpu = config.total_visible_gpus > 0
if self.use_gpu > 0:
self.centroids = centroids.cuda().half()
Expand Down Expand Up @@ -159,7 +160,7 @@ def compress(self, embs):
codes = torch.cat(codes)
residuals = torch.cat(residuals)

return ResidualCodec.Embeddings(codes, residuals)
return ResidualCodec.Embeddings(codes, residuals, mmap_index=self.mmap_index)

def binarize(self, residuals):
residuals = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
Expand Down
100 changes: 79 additions & 21 deletions colbert/indexing/codecs/residual_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os
import torch
import ujson
from collections import defaultdict, namedtuple

from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.utils.utils import print_message


class ResidualEmbeddings:
Strided = ResidualEmbeddingsStrided

def __init__(self, codes, residuals, mmap_index=False):
def __init__(self, codes, residuals, mmap_index=False, pid_to_chunk_metadata=None):
"""
Supply the already compressed residuals.
"""
Expand All @@ -17,6 +19,8 @@ def __init__(self, codes, residuals, mmap_index=False):
if self.mmap_index:
self.codes = codes
self.residuals = residuals
self.pid_to_chunk_metadata = pid_to_chunk_metadata
return

# assert isinstance(residuals, bitarray), type(residuals)
assert codes.size(0) == residuals.size(0), (codes.size(), residuals.size())
Expand Down Expand Up @@ -45,20 +49,23 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False):
codes_offset = 0
pid_offset = 0

per_chunk_doclens = {}
pid_to_chunk_idx = {}
ChunkMetadata = namedtuple('ChunkMetadata', 'chunk_id, passage_doclen, passage_offset')
pid_to_chunk_metadata = {} # pid -> [chunk id, passage doclen, passage offset in the chunk]

for chunk_idx in chunk_idxs:
with open(os.path.join(index_path, f'{chunk_idx}.metadata.json')) as f:
metadata = ujson.load(f)

with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
chunk_doclens = ujson.load(f)

pid_offset_in_chunk = 0
for pid in range(pid_offset, pid_offset + metadata["num_passages"]):
pid_to_chunk_idx[pid] = chunk_idx
pid_doclen = chunk_doclens[pid - pid_offset]
pid_to_chunk_metadata[pid] = ChunkMetadata(chunk_idx, pid_doclen, pid_offset_in_chunk)
pid_offset_in_chunk += pid_doclen
pid_offset += metadata["num_passages"]

with open(os.path.join(index_path, f'{chunk_idx}.doclens.json')) as f:
per_chunk_doclens[chunk_idx] = ujson.load(f)

codes_endpos = codes_offset + metadata["num_embeddings"]

chunk = cls.load(index_path, chunk_idx, codes_offset, codes_endpos, packed_dim, mmap_index)
Expand All @@ -76,8 +83,9 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False):
codes_offset = codes_endpos

# codes, residuals = codes.cuda(), residuals.cuda() # FIXME: REMOVE THIS LINE!
print(f"code is {codes}")

return cls(codes, residuals)
return cls(codes, residuals, mmap_index=mmap_index, pid_to_chunk_metadata=pid_to_chunk_metadata)

@classmethod
def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False):
Expand All @@ -87,7 +95,7 @@ def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=Fals
return cls(codes, residuals)

@classmethod
def load_codes(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False):
def load_codes(self, index_path, chunk_idx, offset=None, endpos=None, packed_dim=None, mmap_index=False):
codes_path = os.path.join(index_path, f'{chunk_idx}.codes.pt')

if mmap_index:
Expand All @@ -109,33 +117,83 @@ def load_residuals(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap

return torch.load(residuals_path, map_location='cpu')

def save(self, path_prefix):
def save(self, index_path, chunk_idx):
path_prefix = os.path.join(index_path, str(chunk_idx))
codes_path = f'{path_prefix}.codes.pt'
residuals_path = f'{path_prefix}.residuals.pt' # f'{path_prefix}.residuals.bn'

torch.save(self.codes, codes_path)
torch.save(self.residuals, residuals_path)
print(f"saving code {self.codes}, {self.codes.shape[0]}")
if self.mmap_index:
print("using mmap")
codes_size = self.codes.shape[0]
storage = torch.IntStorage.from_file(codes_path, True, codes_size)
torch.IntTensor(storage).copy_(self.codes)

dim, nbits = get_dim_and_nbits(index_path)
packed_dim = dim // 8 * nbits
residuals_size = codes_size * packed_dim
storage = torch.ByteStorage.from_file(residuals_path, True, residuals_size)
torch.ByteTensor(storage).copy_(self.residuals)
else:
torch.save(self.codes, codes_path)
torch.save(self.residuals, residuals_path)
# _save_bitarray(self.residuals, residuals_path)

def lookup_codes(self, pids):
assert self.mmap_index
codes = torch.zeros((sum(self.doclens[pid] for pid in pids]))
# prev_pid = 0
# for pid in pids:
# if pid.item() < prev_pid:
# print_message("not in order")
# prev_pid = pid.item()

pids_per_chunk = defaultdict(list)
for pid in pids:
chunk_idx = self.pid_to_chunk_idx[pid]
pids_per_chunk[chunk_idx].append(pid)
codes_lengths = torch.zeros(len(pids))
codes_size = 0
for idx, pid in enumerate(pids):
pid_ = pid.item()
chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_]
pids_per_chunk[chunk_idx].append(pid_)
codes_lengths[idx] = pid_doclen
codes_size += pid_doclen
codes = torch.zeros(codes_size, dtype=torch.int32)

offset = 0
for chunk_idx in sorted(chunks.keys()):
for chunk_idx in sorted(pids_per_chunk.keys()):
Copy link
Collaborator

@santhnm2 santhnm2 May 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a problem in the code I wrote, but one concern I have is that it may be possible for the pids passed in to lookup_pids to not be in sorted order - if that is the case, then iterating over the chunks in sorted order would produce an output that's inconsistent with what the calling function expects

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the solution would be to iterate over the pids in the original order, then find the chunk each pid belongs to and extract the relevant data from that chunk
Or alternatively, we could sort the pids upstream somewhere (assuming this is not done already)

pids_ = pids_per_chunk[chunk_idx]
for pid in pids_:
codes[offset:offset + self.doclens[pid]] = self.codes[chunk_idx][self.chunk_offsets[pid]:self.chunk_offsets[pid] + doclens[pid]]
offset += doclens[pid]
_, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid]
codes[offset:offset + pid_doclen] = \
self.codes[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen]
offset += pid_doclen

return codes
return codes, codes_lengths.long()

def lookup_pids(self, pids):
assert self.mmap_index
pass
packed_dim = self.residuals[0].shape[1]

pids_per_chunk = defaultdict(list)
residuals_lengths = torch.zeros(len(pids))
residuals_size = 0
for idx, pid in enumerate(pids):
pid_ = pid.item()
chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_]
pids_per_chunk[chunk_idx].append(pid_)
residuals_lengths[idx] = pid_doclen
residuals_size += pid_doclen
residuals = torch.zeros(residuals_size, packed_dim, dtype=torch.uint8)

offset = 0
for chunk_idx in sorted(pids_per_chunk.keys()):
pids_ = pids_per_chunk[chunk_idx]
for pid in pids_:
_, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid]
residuals[offset:offset + pid_doclen] = \
self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen]
offset += pid_doclen

return residuals, residuals_lengths

def __len__(self):
return self.codes.size(0)
Expand Down
2 changes: 1 addition & 1 deletion colbert/indexing/collection_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def _build_ivf(self):

for chunk_idx in range(self.num_chunks):
offset = self.embedding_offsets[chunk_idx]
chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx)
chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx, mmap_index=self.config.mmap_index)

codes[offset:offset+chunk_codes.size(0)] = chunk_codes

Expand Down
3 changes: 1 addition & 2 deletions colbert/indexing/index_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def _saver_thread(self):
self._write_chunk_to_disk(*args)

def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens):
path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))
compressed_embs.save(path_prefix)
compressed_embs.save(self.config.index_path_, chunk_idx)

doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
with open(doclens_path, 'w') as output_doclens:
Expand Down
3 changes: 2 additions & 1 deletion colbert/search/index_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from colbert.indexing.loaders import load_doclens
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs import residual_embeddings

from colbert.search.strided_tensor import StridedTensor
from colbert.search.candidate_generation import CandidateGeneration
Expand Down Expand Up @@ -124,7 +125,7 @@ def score_pids(self, config, Q, pids, centroid_scores):
pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]

# Filter docs using full centroid scores
codes_packed, codes_lengths = self.lookup_codes(pids_)
codes_packed, codes_lengths = self.lookup_codes(pids)
approx_scores = centroid_scores[codes_packed.long()]
approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()
Expand Down