Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add MPS support and forking avoidance for training #325

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions colbert/indexing/codecs/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from colbert.infra.config import ColBERTConfig
from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings
from colbert.utils.utils import print_message
from colbert.parameters import DEVICE

import pathlib
from torch.utils.cpp_extension import load
Expand All @@ -18,33 +19,35 @@
class ResidualCodec:
Embeddings = ResidualEmbeddings

def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None):
self.use_gpu = config.total_visible_gpus > 0
def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None, device=None):
if device is None:
device = DEVICE
self.device = device

ResidualCodec.try_load_torch_extensions(self.use_gpu)

if self.use_gpu > 0:
self.centroids = centroids.cuda().half()
if self.device.type in ["cuda", "mps"]:
self.centroids = self.centroids.to(self.device).half()
else:
self.centroids = centroids.float()
self.dim, self.nbits = config.dim, config.nbits
self.avg_residual = avg_residual

if torch.is_tensor(self.avg_residual):
if self.use_gpu:
self.avg_residual = self.avg_residual.cuda().half()
if self.device.type in ["cuda", "mps"]:
self.avg_residual = self.avg_residual.to(self.device).half()

if torch.is_tensor(bucket_cutoffs):
if self.use_gpu:
bucket_cutoffs = bucket_cutoffs.cuda()
bucket_weights = bucket_weights.half().cuda()
if self.device.type in ["cuda", "mps"]:
bucket_cutoffs = bucket_cutoffs.to(self.device)
bucket_weights = bucket_weights.to(self.device)

self.bucket_cutoffs = bucket_cutoffs
self.bucket_weights = bucket_weights
if not self.use_gpu and self.bucket_weights is not None:
self.bucket_weights = self.bucket_weights.to(torch.float32)

self.arange_bits = torch.arange(0, self.nbits, device='cuda' if self.use_gpu else 'cpu', dtype=torch.uint8)
self.arange_bits = torch.arange(0, self.nbits, device=self.device.type, dtype=torch.uint8)

self.rank = config.rank

Expand Down Expand Up @@ -89,10 +92,10 @@ def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bu
)
else:
self.decompression_lookup_table = None
if self.use_gpu:
self.reversed_bit_map = self.reversed_bit_map.cuda()
if self.device.type in ["cuda", "mps"]:
self.reversed_bit_map = self.reversed_bit_map.to(self.device).half()
if self.decompression_lookup_table is not None:
self.decompression_lookup_table = self.decompression_lookup_table.cuda()
self.decompression_lookup_table = self.decompression_lookup_table.to(self.device)

@classmethod
def try_load_torch_extensions(cls, use_gpu):
Expand Down Expand Up @@ -168,8 +171,8 @@ def compress(self, embs):
codes, residuals = [], []

for batch in embs.split(1 << 18):
if self.use_gpu:
batch = batch.cuda().half()
if self.device.type in ["cuda", "mps"]:
batch = batch.to(self.device).half()
codes_ = self.compress_into_codes(batch, out_device=batch.device)
centroids_ = self.lookup_centroids(codes_, out_device=batch.device)

Expand Down Expand Up @@ -211,15 +214,15 @@ def compress_into_codes(self, embs, out_device):

bsize = (1 << 29) // self.centroids.size(0)
for batch in embs.split(bsize):
if self.use_gpu:
indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device)
if self.device.type in ["cuda", "mps"]:
indices = (self.centroids @ batch.T.to(self.device).half()).max(dim=0).indices.to(self.device)
else:
indices = (self.centroids @ batch.T.cpu().float()).max(dim=0).indices.to(device=out_device)
indices = (self.centroids @ batch.T.to(self.device).float()).max(dim=0).indices
codes.append(indices)

return torch.cat(codes)

def lookup_centroids(self, codes, out_device):
def lookup_centroids(self, codes):
"""
Handles multi-dimensional codes too.

Expand All @@ -228,11 +231,9 @@ def lookup_centroids(self, codes, out_device):

centroids = []


for batch in codes.split(1 << 20):
if self.use_gpu:
centroids.append(self.centroids[batch.cuda().long()].to(device=out_device))
else:
centroids.append(self.centroids[batch.long()].to(device=out_device))
centroids.append(self.centroids[batch.to(self.device).long()].to(self.device))

return torch.cat(centroids)

Expand All @@ -246,8 +247,8 @@ def decompress(self, compressed_embs: Embeddings):

D = []
for codes_, residuals_ in zip(codes.split(1 << 15), residuals.split(1 << 15)):
if self.use_gpu:
codes_, residuals_ = codes_.cuda(), residuals_.cuda()
if self.device.type in ["cuda", "mps"]:
codes_, residuals_ = codes_.to(self.device), residuals_.to(self.device)
centroids_ = ResidualCodec.decompress_residuals(
residuals_,
self.bucket_weights,
Expand All @@ -257,17 +258,17 @@ def decompress(self, compressed_embs: Embeddings):
self.centroids,
self.dim,
self.nbits,
).cuda()
).to(self.device)
else:
# TODO: Remove dead code
centroids_ = self.lookup_centroids(codes_, out_device='cpu')
centroids_ = self.lookup_centroids(codes_)
residuals_ = self.reversed_bit_map[residuals_.long()]
residuals_ = self.decompression_lookup_table[residuals_.long()]
residuals_ = residuals_.reshape(residuals_.shape[0], -1)
residuals_ = self.bucket_weights[residuals_.long()]
centroids_.add_(residuals_)

if self.use_gpu:
if self.device.type in ["cuda", "mps"]:
D_ = torch.nn.functional.normalize(centroids_, p=2, dim=-1).half()
else:
D_ = torch.nn.functional.normalize(centroids_.to(torch.float32), p=2, dim=-1)
Expand Down
69 changes: 30 additions & 39 deletions colbert/indexing/collection_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,26 @@ def __init__(self, config: ColBERTConfig, collection, verbose=2):
self.rank, self.nranks = self.config.rank, self.config.nranks

self.use_gpu = self.config.total_visible_gpus > 0
self.use_mps = False
if torch.backends.mps.is_available() and self.config.use_mps_if_available:
if torch.backends.mps.is_built() :
self.use_mps = True
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
else:
print_message("MPS not available because the current PyTorch install was not built with MPS enabled.")

if self.config.rank == 0 and self.verbose > 1:
self.config.help()

self.collection = Collection.cast(collection)
self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)
self.device = torch.device("cpu")
if self.use_gpu:
self.checkpoint = self.checkpoint.cuda()
self.device = torch.device("cuda")
elif self.use_mps:
print_message("Loading model to MPS")
self.device = torch.device("mps")
self.checkpoint = self.checkpoint.to(self.device)

self.encoder = CollectionEncoder(config, self.checkpoint)
self.saver = IndexSaver(config)
Expand Down Expand Up @@ -136,42 +148,23 @@ def _sample_embeddings(self, sampled_pids):

local_sample_embs, doclens = self.encoder.encode_passages(local_sample)

if torch.cuda.is_available():
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()
torch.distributed.all_reduce(self.num_sample_embs)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()
torch.distributed.all_reduce(avg_doclen_est)

nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
torch.distributed.all_reduce(nonzero_ranks)
else:
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).to(self.device)
torch.distributed.all_reduce(self.num_sample_embs)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()
avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).to(self.device)
torch.distributed.all_reduce(avg_doclen_est)

nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).to(self.device)
torch.distributed.all_reduce(nonzero_ranks)
else:
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cpu()
torch.distributed.all_reduce(self.num_sample_embs)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cpu()
torch.distributed.all_reduce(avg_doclen_est)
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).to(self.device)

nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cpu()
torch.distributed.all_reduce(nonzero_ranks)
else:
self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cpu()
avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).to(self.device)

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est = torch.tensor([avg_doclen_est]).cpu()

nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cpu()
nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).to(self.device)

avg_doclen_est = avg_doclen_est.item() / nonzero_ranks.item()
self.avg_doclen_est = avg_doclen_est
Expand Down Expand Up @@ -241,7 +234,7 @@ def train(self, shared_lists):

# Compute and save codec into avg_residual.pt, buckets.pt and centroids.pt
codec = ResidualCodec(config=self.config, centroids=centroids, avg_residual=avg_residual,
bucket_cutoffs=bucket_cutoffs, bucket_weights=bucket_weights)
bucket_cutoffs=bucket_cutoffs, bucket_weights=bucket_weights, device=self.device)
self.saver.save_codec(codec)

def _concatenate_and_split_sample(self):
Expand Down Expand Up @@ -304,22 +297,20 @@ def _train_kmeans(self, sample, shared_lists):
centroids = compute_faiss_kmeans(*args_)

centroids = torch.nn.functional.normalize(centroids, dim=-1)
if self.use_gpu:
if self.device.type in ["cuda", "mps"]:
centroids = centroids.half()
else:
centroids = centroids.float()

return centroids

def _compute_avg_residual(self, centroids, heldout):
compressor = ResidualCodec(config=self.config, centroids=centroids, avg_residual=None)
compressor = ResidualCodec(config=self.config, centroids=centroids, avg_residual=None, device=self.device)

heldout_reconstruct = compressor.compress_into_codes(heldout, out_device='cuda' if self.use_gpu else 'cpu')
heldout_reconstruct = compressor.lookup_centroids(heldout_reconstruct, out_device='cuda' if self.use_gpu else 'cpu')
if self.use_gpu:
heldout_avg_residual = heldout.cuda() - heldout_reconstruct
else:
heldout_avg_residual = heldout - heldout_reconstruct

heldout_avg_residual = heldout.to(self.device) - heldout_reconstruct

avg_residual = torch.abs(heldout_avg_residual).mean(dim=0).cpu()
print([round(x, 3) for x in avg_residual.squeeze().tolist()])
Expand Down
2 changes: 2 additions & 0 deletions colbert/infra/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class RunSettings:

avoid_fork_if_possible: bool = DefaultVal(False)

use_mps_if_available: bool = DefaultVal(True)

@property
def gpus_(self):
value = self.gpus
Expand Down
8 changes: 5 additions & 3 deletions colbert/modeling/base_colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ class BaseColBERT(torch.nn.Module):
Like HF, evaluation mode is the default.
"""

def __init__(self, name_or_path, colbert_config=None):
def __init__(self, name_or_path, colbert_config=None, device=None):
super().__init__()

self.colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)
self.name = self.colbert_config.model_name or name_or_path
if device is None:
device = DEVICE
self._device = device

try:
HF_ColBERT = class_factory(self.name)
Expand All @@ -34,7 +36,7 @@ def __init__(self, name_or_path, colbert_config=None):
# HF_ColBERT = class_factory(self.name)

self.model = HF_ColBERT.from_pretrained(name_or_path, colbert_config=self.colbert_config)
self.model.to(DEVICE)
self.model.to(self._device)
self.raw_tokenizer = AutoTokenizer.from_pretrained(name_or_path)

self.eval()
Expand Down
8 changes: 4 additions & 4 deletions colbert/modeling/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ class Checkpoint(ColBERT):
TODO: Add .cast() accepting [also] an object instance-of(Checkpoint) as first argument.
"""

def __init__(self, name, colbert_config=None, verbose:int = 3):
super().__init__(name, colbert_config)
def __init__(self, name, colbert_config=None, verbose:int = 3, device=None):
super().__init__(name, colbert_config, device=None)
assert self.training is False

self.verbose = verbose

self.query_tokenizer = QueryTokenizer(self.colbert_config, verbose=self.verbose)
self.doc_tokenizer = DocTokenizer(self.colbert_config)
self.query_tokenizer = QueryTokenizer(self.colbert_config, verbose=self.verbose, device=device)
self.doc_tokenizer = DocTokenizer(self.colbert_config, device=device)

self.amp_manager = MixedPrecisionManager(True)

Expand Down
28 changes: 14 additions & 14 deletions colbert/modeling/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ class ColBERT(BaseColBERT):
This class handles the basic encoding and scoring operations in ColBERT. It is used for training.
"""

def __init__(self, name='bert-base-uncased', colbert_config=None):
super().__init__(name, colbert_config)
self.use_gpu = colbert_config.total_visible_gpus > 0
def __init__(self, name='bert-base-uncased', colbert_config=None, device=None):
super().__init__(name, colbert_config, device=device)

ColBERT.try_load_torch_extensions(self.use_gpu)
ColBERT.try_load_torch_extensions(self._device.type=="cuda")

if self.colbert_config.mask_punctuation:
self.skiplist = {w: True
Expand Down Expand Up @@ -83,26 +82,26 @@ def compute_ib_loss(self, Q, D, D_mask):
return torch.nn.CrossEntropyLoss()(scores, labels)

def query(self, input_ids, attention_mask):
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
input_ids, attention_mask = input_ids.to(self._device), attention_mask.to(self._device)
Q = self.bert(input_ids, attention_mask=attention_mask)[0]
Q = self.linear(Q)

mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self._device).unsqueeze(2).float()
Q = Q * mask

return torch.nn.functional.normalize(Q, p=2, dim=2)

def doc(self, input_ids, attention_mask, keep_dims=True):
assert keep_dims in [True, False, 'return_mask']

input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
input_ids, attention_mask = input_ids.to(self._device), attention_mask.to(self._device)
D = self.bert(input_ids, attention_mask=attention_mask)[0]
D = self.linear(D)
mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self._device).unsqueeze(2).float()
D = D * mask

D = torch.nn.functional.normalize(D, p=2, dim=2)
if self.use_gpu:
if self._device.type in ["cuda", "mps"]:
D = D.half()

if keep_dims is False:
Expand All @@ -119,7 +118,7 @@ def score(self, Q, D_padded, D_mask):
if self.colbert_config.similarity == 'l2':
assert self.colbert_config.interaction == 'colbert'
return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
return colbert_score(Q, D_padded, D_mask, config=self.colbert_config)
return colbert_score(Q, D_padded, D_mask, config=self.colbert_config, torch_device=self._device)

def mask(self, input_ids, skiplist):
mask = [[(x not in skiplist) and (x != self.pad_token) for x in d] for d in input_ids.cpu().tolist()]
Expand Down Expand Up @@ -155,7 +154,7 @@ def colbert_score_reduce(scores_padded, D_mask, config: ColBERTConfig):


# TODO: Wherever this is called, pass `config=`
def colbert_score(Q, D_padded, D_mask, config=ColBERTConfig()):
def colbert_score(Q, D_padded, D_mask, config=ColBERTConfig(), torch_device=None):
"""
Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
If Q.size(0) is 1, the matrix will be compared with all passages.
Expand All @@ -164,9 +163,10 @@ def colbert_score(Q, D_padded, D_mask, config=ColBERTConfig()):
EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
"""

use_gpu = config.total_visible_gpus > 0
if use_gpu:
Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
if torch_device is None:
torch_device = DEVICE
if torch_device.type in ["cuda", "mps"]:
Q, D_padded, D_mask = Q.to(torch_device), D_padded.to(torch_device), D_mask.to(torch_device)

assert Q.dim() == 3, Q.size()
assert D_padded.dim() == 3, D_padded.size()
Expand Down
Loading