diff --git a/reinvent_models/lib_invent/enums/generative_model_parameters.py b/reinvent_models/lib_invent/enums/generative_model_parameters.py
index 564f81a..a93c725 100644
--- a/reinvent_models/lib_invent/enums/generative_model_parameters.py
+++ b/reinvent_models/lib_invent/enums/generative_model_parameters.py
@@ -1,3 +1,5 @@
+
+
 class GenerativeModelParametersEnum:
     NUMBER_OF_LAYERS = "num_layers"
     NUMBER_OF_DIMENSIONS = "num_dimensions"
diff --git a/reinvent_models/lib_invent/enums/generative_model_regime.py b/reinvent_models/lib_invent/enums/generative_model_regime.py
index b6ce647..989cc27 100644
--- a/reinvent_models/lib_invent/enums/generative_model_regime.py
+++ b/reinvent_models/lib_invent/enums/generative_model_regime.py
@@ -1,3 +1,5 @@
+
+
 class GenerativeModelRegimeEnum:
     INFERENCE = "inference"
     TRAINING = "training"
diff --git a/reinvent_models/lib_invent/models/dataset.py b/reinvent_models/lib_invent/models/dataset.py
index 6572dc0..064df68 100644
--- a/reinvent_models/lib_invent/models/dataset.py
+++ b/reinvent_models/lib_invent/models/dataset.py
@@ -27,9 +27,7 @@ def __init__(self, smiles_list, vocabulary, tokenizer):
             self._encoded_list.append(enc)
 
     def __getitem__(self, i):
-        return torch.tensor(
-            self._encoded_list[i], dtype=torch.long
-        )  # pylint: disable=E1102
+        return torch.tensor(self._encoded_list[i], dtype=torch.long)  # pylint: disable=E1102
 
     def __len__(self):
         return len(self._encoded_list)
@@ -47,21 +45,14 @@ def __init__(self, scaffold_decoration_smi_list, vocabulary):
         self._encoded_list = []
         for scaffold, dec in scaffold_decoration_smi_list:
-            en_scaff = self.vocabulary.scaffold_vocabulary.encode(
-                self.vocabulary.scaffold_tokenizer.tokenize(scaffold)
-            )
-            en_dec = self.vocabulary.decoration_vocabulary.encode(
-                self.vocabulary.decoration_tokenizer.tokenize(dec)
-            )
+            en_scaff = self.vocabulary.scaffold_vocabulary.encode(self.vocabulary.scaffold_tokenizer.tokenize(scaffold))
+            en_dec = self.vocabulary.decoration_vocabulary.encode(self.vocabulary.decoration_tokenizer.tokenize(dec))
             if en_scaff is not None and en_dec is not None:
                 self._encoded_list.append((en_scaff, en_dec))
 
     def __getitem__(self, i):
         scaff, dec = self._encoded_list[i]
-        return (
-            torch.tensor(scaff, dtype=torch.long),
-            torch.tensor(dec, dtype=torch.long),
-        )  # pylint: disable=E1102
+        return (torch.tensor(scaff, dtype=torch.long), torch.tensor(dec, dtype=torch.long))  # pylint: disable=E1102
 
     def __len__(self):
         return len(self._encoded_list)
@@ -83,9 +74,7 @@ def pad_batch(encoded_seqs):
     :param encoded_seqs: A list of encoded sequences.
     :return: A tensor with the sequences correctly padded.
""" - seq_lengths = torch.tensor( - [len(seq) for seq in encoded_seqs], dtype=torch.int64 - ) # pylint: disable=not-callable + seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64) # pylint: disable=not-callable if torch.cuda.is_available(): return (tnnur.pad_sequence(encoded_seqs, batch_first=True).cuda(), seq_lengths) return (tnnur.pad_sequence(encoded_seqs, batch_first=True), seq_lengths) diff --git a/reinvent_models/lib_invent/models/decorator.py b/reinvent_models/lib_invent/models/decorator.py index 5c512c8..fa06ef8 100644 --- a/reinvent_models/lib_invent/models/decorator.py +++ b/reinvent_models/lib_invent/models/decorator.py @@ -7,9 +7,7 @@ import torch.nn as tnn import torch.nn.utils.rnn as tnnur -from reinvent_models.lib_invent.enums.generative_model_parameters import ( - GenerativeModelParametersEnum, -) +from reinvent_models.lib_invent.enums.generative_model_parameters import GenerativeModelParametersEnum class Encoder(tnn.Module): @@ -27,16 +25,10 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout): self._embedding = tnn.Sequential( tnn.Embedding(self.vocabulary_size, self.num_dimensions), - tnn.Dropout(dropout), - ) - self._rnn = tnn.LSTM( - self.num_dimensions, - self.num_dimensions, - self.num_layers, - batch_first=True, - dropout=self.dropout, - bidirectional=True, + tnn.Dropout(dropout) ) + self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers, + batch_first=True, dropout=self.dropout, bidirectional=True) def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ # FIXME: This fails with a batch of 1 because squeezing looses a dimension with size 1 @@ -53,40 +45,26 @@ def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ padded_seqs = self._embedding(padded_seqs) hs_h, hs_c = (hidden_state, hidden_state.clone().detach()) - # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 + #FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 seq_lengths = seq_lengths.cpu() - packed_seqs = tnnur.pack_padded_sequence( - padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False - ) + packed_seqs = tnnur.pack_padded_sequence(padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False) packed_seqs, (hs_h, hs_c) = self._rnn(packed_seqs, (hs_h, hs_c)) padded_seqs, _ = tnnur.pad_packed_sequence(packed_seqs, batch_first=True) # sum up bidirectional layers and collapse - hs_h = ( - hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions) - .sum(dim=1) - .squeeze() - ) # (layers, batch, dim) - hs_c = ( - hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions) - .sum(dim=1) - .squeeze() - ) # (layers, batch, dim) - padded_seqs = ( - padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions) - .sum(dim=2) - .squeeze() - ) # (batch, seq, dim) + hs_h = hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions)\ + .sum(dim=1).squeeze() # (layers, batch, dim) + hs_c = hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions)\ + .sum(dim=1).squeeze() # (layers, batch, dim) + padded_seqs = padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions)\ + .sum(dim=2).squeeze() # (batch, seq, dim) return padded_seqs, (hs_h, hs_c) def _initialize_hidden_state(self, batch_size): if torch.cuda.is_available(): - return torch.zeros( - self.num_layers * 2, batch_size, self.num_dimensions - ).cuda() - return torch.zeros(self.num_layers 
* 2, batch_size, self.num_dimensions)
+            return torch.zeros(self.num_layers*2, batch_size, self.num_dimensions).cuda()
+        return torch.zeros(self.num_layers*2, batch_size, self.num_dimensions)
 
     def get_params(self):
         parameter_enums = GenerativeModelParametersEnum
         """
         Obtains the params for the network.
         :return : A dict with the params.
         """
         return {
             parameter_enums.NUMBER_OF_LAYERS: self.num_layers,
             parameter_enums.NUMBER_OF_DIMENSIONS: self.num_dimensions,
             parameter_enums.VOCABULARY_SIZE: self.vocabulary_size,
-            parameter_enums.DROPOUT: self.dropout,
+            parameter_enums.DROPOUT: self.dropout
         }
 
 
 class AttentionLayer(tnn.Module):
+
     def __init__(self, num_dimensions):
         super(AttentionLayer, self).__init__()
 
         self.num_dimensions = num_dimensions
 
         self._attention_linear = tnn.Sequential(
-            tnn.Linear(self.num_dimensions * 2, self.num_dimensions), tnn.Tanh()
+            tnn.Linear(self.num_dimensions*2, self.num_dimensions),
+            tnn.Tanh()
         )
 
-    def forward(
-        self, padded_seqs, encoder_padded_seqs, decoder_mask
-    ):  # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, encoder_padded_seqs, decoder_mask):  # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -124,19 +102,12 @@ def forward(
         """
         # scaled dot-product
         # (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) => (batch, seq_d, seq_e*)
-        attention_weights = (
-            (padded_seqs.unsqueeze(dim=2) * encoder_padded_seqs.unsqueeze(dim=1))
-            .sum(dim=3)
-            .div(math.sqrt(self.num_dimensions))
+        attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\
+            .sum(dim=3).div(math.sqrt(self.num_dimensions))\
             .softmax(dim=2)
-        )
         # (batch, seq_d, seq_e*)@(batch, seq_e, dim) => (batch, seq_d, dim)
         attention_context = attention_weights.bmm(encoder_padded_seqs)
-        return (
-            self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))
-            * decoder_mask,
-            attention_weights,
-        )
+        return (self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))*decoder_mask, attention_weights)
 
 
 class Decoder(tnn.Module):
@@ -154,26 +125,16 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout):
         self._embedding = tnn.Sequential(
             tnn.Embedding(self.vocabulary_size, self.num_dimensions),
-            tnn.Dropout(dropout),
-        )
-        self._rnn = tnn.LSTM(
-            self.num_dimensions,
-            self.num_dimensions,
-            self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
-            bidirectional=False,
+            tnn.Dropout(dropout)
         )
+        self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers,
+                             batch_first=True, dropout=self.dropout, bidirectional=False)
 
         self._attention = AttentionLayer(self.num_dimensions)
 
-        self._linear = tnn.Linear(
-            self.num_dimensions, self.vocabulary_size
-        )  # just to redimension
+        self._linear = tnn.Linear(self.num_dimensions, self.vocabulary_size)  # just to redimension
 
-    def forward(
-        self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-    ):  # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states):  # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
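Review note: the FIXME in Encoder.forward above describes a real shape hazard. Once the bidirectional axis has been reduced away, a bare .squeeze() also drops the batch axis whenever the batch holds a single sequence. A minimal sketch (tensor sizes are illustrative):

import torch

hs = torch.zeros(3, 2, 1, 8)       # (layers, directions, batch=1, dim)
summed = hs.sum(dim=1)             # (layers, batch, dim) -> torch.Size([3, 1, 8])
print(summed.squeeze().shape)      # torch.Size([3, 8]) -- the batch axis is silently lost

The link_invent Encoder added later in this patch avoids the problem by dropping the trailing .squeeze() after .sum(...), and the link_invent sampling loop uses squeeze(dim=1) instead of a bare squeeze().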
@@ -187,20 +148,13 @@ def forward( padded_encoded_seqs = self._embedding(padded_seqs) packed_encoded_seqs = tnnur.pack_padded_sequence( - padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False - ) - packed_encoded_seqs, hidden_states = self._rnn( - packed_encoded_seqs, hidden_states - ) - padded_encoded_seqs, _ = tnnur.pad_packed_sequence( - packed_encoded_seqs, batch_first=True - ) # (batch, seq, dim) + padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False) + packed_encoded_seqs, hidden_states = self._rnn(packed_encoded_seqs, hidden_states) + padded_encoded_seqs, _ = tnnur.pad_packed_sequence(packed_encoded_seqs, batch_first=True) # (batch, seq, dim) mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float) - attn_padded_encoded_seqs, attention_weights = self._attention( - padded_encoded_seqs, encoder_padded_seqs, mask - ) - logits = self._linear(attn_padded_encoded_seqs) * mask # (batch, seq, voc_size) + attn_padded_encoded_seqs, attention_weights = self._attention(padded_encoded_seqs, encoder_padded_seqs, mask) + logits = self._linear(attn_padded_encoded_seqs)*mask # (batch, seq, voc_size) return logits, hidden_states, attention_weights def get_params(self): @@ -213,7 +167,7 @@ def get_params(self): parameter_enum.NUMBER_OF_LAYERS: self.num_layers, parameter_enum.NUMBER_OF_DIMENSIONS: self.num_dimensions, parameter_enum.VOCABULARY_SIZE: self.vocabulary_size, - parameter_enum.DROPOUT: self.dropout, + parameter_enum.DROPOUT: self.dropout } @@ -228,9 +182,7 @@ def __init__(self, encoder_params, decoder_params): self._encoder = Encoder(**encoder_params) self._decoder = Decoder(**decoder_params) - def forward( - self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths - ): # pylint: disable=arguments-differ + def forward(self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths): # pylint: disable=arguments-differ """ Performs the forward pass. :param encoder_seqs: A tensor with the output sequences (batch, seq_d, dim). @@ -239,12 +191,8 @@ def forward( :param decoder_seq_lengths: The lengths of the decoder sequences. :return : The output logits as a tensor (batch, seq_d, dim). """ - encoder_padded_seqs, hidden_states = self.forward_encoder( - encoder_seqs, encoder_seq_lengths - ) - logits, _, _ = self.forward_decoder( - decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states - ) + encoder_padded_seqs, hidden_states = self.forward_encoder(encoder_seqs, encoder_seq_lengths) + logits, _, _ = self.forward_decoder(decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states) return logits def forward_encoder(self, padded_seqs, seq_lengths): @@ -256,9 +204,7 @@ def forward_encoder(self, padded_seqs, seq_lengths): """ return self._encoder(padded_seqs, seq_lengths) - def forward_decoder( - self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states - ): + def forward_decoder(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states): """ Does a forward pass only of the decoder. :param hidden_states: The hidden states from the encoder. @@ -266,9 +212,7 @@ def forward_decoder( :param seq_lengths: The length of each sequence in the batch. :return : Returns the logits and the hidden state for each element of the sequence passed. 
""" - return self._decoder( - padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states - ) + return self._decoder(padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states) def get_params(self): """ @@ -277,5 +221,5 @@ def get_params(self): """ return { "encoder_params": self._encoder.get_params(), - "decoder_params": self._decoder.get_params(), + "decoder_params": self._decoder.get_params() } diff --git a/reinvent_models/lib_invent/models/model.py b/reinvent_models/lib_invent/models/model.py index 04f27e9..f27782d 100644 --- a/reinvent_models/lib_invent/models/model.py +++ b/reinvent_models/lib_invent/models/model.py @@ -6,22 +6,14 @@ import torch.nn as tnn -from reinvent_models.lib_invent.enums.generative_model_regime import ( - GenerativeModelRegimeEnum, -) +from reinvent_models.lib_invent.enums.generative_model_regime import GenerativeModelRegimeEnum from reinvent_models.lib_invent.models.decorator import Decorator from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum class DecoratorModel: - def __init__( - self, - vocabulary, - decorator, - max_sequence_length=256, - no_cuda=False, - mode=ModelModeEnum().TRAINING, - ): + + def __init__(self, vocabulary, decorator, max_sequence_length=256, no_cuda=False, mode=ModelModeEnum().TRAINING): """ Implements the likelihood and scaffold_decorating functions of the decorator model. :param vocabulary: A DecoratorVocabulary instance with the vocabularies of both the encoder and decoder. @@ -59,7 +51,11 @@ def load_from_file(cls, path, mode=ModelModeEnum().TRAINING): decorator = Decorator(**data["decorator"]["params"]) decorator.load_state_dict(data["decorator"]["state"]) - model = DecoratorModel(decorator=decorator, mode=mode, **data["model"]) + model = DecoratorModel( + decorator=decorator, + mode=mode, + **data["model"] + ) return model @@ -69,14 +65,14 @@ def save(self, path): :param path: Path to the file which the model will be saved to. """ save_dict = { - "model": { - "vocabulary": self.vocabulary, - "max_sequence_length": self.max_sequence_length, - }, - "decorator": { - "params": self.network.get_params(), - "state": self.network.state_dict(), + 'model': { + 'vocabulary': self.vocabulary, + 'max_sequence_length': self.max_sequence_length }, + 'decorator': { + 'params': self.network.get_params(), + 'state': self.network.state_dict() + } } torch.save(save_dict, path) @@ -92,13 +88,7 @@ def set_mode(self, mode): self.network.train() return self - def likelihood( - self, - scaffold_seqs, - scaffold_seq_lengths, - decoration_seqs, - decoration_seq_lengths, - ): + def likelihood(self, scaffold_seqs, scaffold_seq_lengths, decoration_seqs, decoration_seq_lengths): """ Retrieves the likelihood of a scaffold and its respective decorations. :param scaffold_seqs: (batch, seq) A batch of padded scaffold sequences. @@ -109,12 +99,8 @@ def likelihood( """ # NOTE: the decoration_seq_lengths have a - 1 to prevent the end token to be forward-passed. 
-        logits = self.network(
-            scaffold_seqs,
-            scaffold_seq_lengths,
-            decoration_seqs,
-            decoration_seq_lengths - 1,
-        )  # (batch, seq - 1, voc)
+        logits = self.network(scaffold_seqs, scaffold_seq_lengths, decoration_seqs,
+                              decoration_seq_lengths - 1)  # (batch, seq - 1, voc)
         log_probs = logits.log_softmax(dim=2).transpose(1, 2)  # (batch, voc, seq - 1)
         return self._nll_loss(log_probs, decoration_seqs[:, 1:]).sum(dim=1)  # (batch)
@@ -128,47 +114,36 @@ def sample_decorations(self, scaffold_seqs, scaffold_seq_lengths):
         :return: An iterator with (scaffold_smi, decoration_smi, nll) triplets.
         """
         batch_size = scaffold_seqs.size(0)
-
+
         input_vector = torch.full(
-            (batch_size, 1),
-            self.vocabulary.decoration_vocabulary["^"],
-            dtype=torch.long,
-        )  # (batch, 1)
-
-        seq_lengths = torch.ones(batch_size)  # (batch)
-        encoder_padded_seqs, hidden_states = self.network.forward_encoder(
-            scaffold_seqs, scaffold_seq_lengths
-        )
+            (batch_size, 1), self.vocabulary.decoration_vocabulary["^"], dtype=torch.long)
         nlls = torch.zeros(batch_size)
         not_finished = torch.ones(batch_size, 1, dtype=torch.long)
+
         if torch.cuda.is_available():
-            input_vector = input_vector.cuda()
-            nlls = nlls.cuda()
-            input_vector = input_vector.cuda()
+            input_vector = torch.full(
+                (batch_size, 1), self.vocabulary.decoration_vocabulary["^"], dtype=torch.long).cuda()  # (batch, 1)
+            nlls = torch.zeros(batch_size).cuda()
+            not_finished = torch.ones(batch_size, 1, dtype=torch.long).cuda()
+
+        seq_lengths = torch.ones(batch_size)  # (batch)
+        encoder_padded_seqs, hidden_states = self.network.forward_encoder(scaffold_seqs, scaffold_seq_lengths)
         sequences = []
         for _ in range(self.max_sequence_length - 1):
             logits, hidden_states, _ = self.network.forward_decoder(
-                input_vector, seq_lengths, encoder_padded_seqs, hidden_states
-            )  # (batch, 1, voc)
+                input_vector, seq_lengths, encoder_padded_seqs, hidden_states)  # (batch, 1, voc)
             probs = logits.softmax(dim=2).squeeze()  # (batch, voc)
             log_probs = logits.log_softmax(dim=2).squeeze()  # (batch, voc)
             input_vector = torch.multinomial(probs, 1) * not_finished  # (batch, 1)
             sequences.append(input_vector)
             nlls += self._nll_loss(log_probs, input_vector.squeeze())
-            not_finished = (input_vector > 1).type(
-                torch.long
-            )  # 0 is padding, 1 is end token
+            not_finished = (input_vector > 1).type(torch.long)  # 0 is padding, 1 is end token
             if not_finished.sum() == 0:
                 break
 
-        decoration_smiles = [
-            self.vocabulary.decode_decoration(seq)
-            for seq in torch.cat(sequences, 1).data.cpu().numpy()
-        ]
-        scaffold_smiles = [
-            self.vocabulary.decode_scaffold(seq)
-            for seq in scaffold_seqs.data.cpu().numpy()
-        ]
+        decoration_smiles = [self.vocabulary.decode_decoration(seq)
+                             for seq in torch.cat(sequences, 1).data.cpu().numpy()]
+        scaffold_smiles = [self.vocabulary.decode_scaffold(seq) for seq in scaffold_seqs.data.cpu().numpy()]
         return zip(scaffold_smiles, decoration_smiles, nlls.data.cpu().numpy().tolist())
 
     def get_network_parameters(self):
diff --git a/reinvent_models/lib_invent/models/vocabulary.py b/reinvent_models/lib_invent/models/vocabulary.py
index 1ca75a5..bc8f941 100644
--- a/reinvent_models/lib_invent/models/vocabulary.py
+++ b/reinvent_models/lib_invent/models/vocabulary.py
@@ -137,7 +137,7 @@ class SMILESTokenizer:
     REGEXPS = {
         "brackets": re.compile(r"(\[[^\]]*\])"),
         "2_ring_nums": re.compile(r"(%\d{2})"),
-        "brcl": re.compile(r"(Br|Cl)"),
+        "brcl": re.compile(r"(Br|Cl)")
     }
     REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]
 
@@ -148,7 +148,6 @@ def tokenize(self,
smiles, with_begin_and_end=True): :param with_begin_and_end: Appends a begin token and prepends an end token. :return : A list with the tokenized version. """ - def split_by(smiles, regexps): if not regexps: return list(smiles) @@ -203,13 +202,7 @@ class DecoratorVocabulary: Encapsulation of the two vocabularies needed for the decorator. """ - def __init__( - self, - scaffold_vocabulary, - scaffold_tokenizer, - decoration_vocabulary, - decoration_tokenizer, - ): + def __init__(self, scaffold_vocabulary, scaffold_tokenizer, decoration_vocabulary, decoration_tokenizer): self.scaffold_vocabulary = scaffold_vocabulary self.scaffold_tokenizer = scaffold_tokenizer self.decoration_vocabulary = decoration_vocabulary @@ -248,9 +241,7 @@ def decode_scaffold(self, encoded_scaffold): :param encoded_scaffold: A one-hot encoded version of the scaffold. :return : A SMILES of the scaffold. """ - return self.scaffold_tokenizer.untokenize( - self.scaffold_vocabulary.decode(encoded_scaffold) - ) + return self.scaffold_tokenizer.untokenize(self.scaffold_vocabulary.decode(encoded_scaffold)) def encode_decoration(self, smiles): """ @@ -258,9 +249,7 @@ def encode_decoration(self, smiles): :param smiles: Decoration SMILES to encode. :return : An one-hot-encoded vector with the fragment information. """ - return self.decoration_vocabulary.encode( - self.decoration_tokenizer.tokenize(smiles) - ) + return self.decoration_vocabulary.encode(self.decoration_tokenizer.tokenize(smiles)) def decode_decoration(self, encoded_decoration): """ @@ -268,9 +257,7 @@ def decode_decoration(self, encoded_decoration): :param encoded_decorations: A one-hot encoded version of the decoration. :return : A list with SMILES of all the fragments. """ - return self.decoration_tokenizer.untokenize( - self.decoration_vocabulary.decode(encoded_decoration) - ) + return self.decoration_tokenizer.untokenize(self.decoration_vocabulary.decode(encoded_decoration)) @classmethod def from_lists(cls, scaffold_list, decoration_list): @@ -286,9 +273,4 @@ def from_lists(cls, scaffold_list, decoration_list): decoration_tokenizer = SMILESTokenizer() decoration_vocabulary = create_vocabulary(decoration_list, decoration_tokenizer) - return DecoratorVocabulary( - scaffold_vocabulary, - scaffold_tokenizer, - decoration_vocabulary, - decoration_tokenizer, - ) + return DecoratorVocabulary(scaffold_vocabulary, scaffold_tokenizer, decoration_vocabulary, decoration_tokenizer) diff --git a/reinvent_models/link_invent/dataset/dataset.py b/reinvent_models/link_invent/dataset/dataset.py index 6237a16..b7daea9 100644 --- a/reinvent_models/link_invent/dataset/dataset.py +++ b/reinvent_models/link_invent/dataset/dataset.py @@ -6,9 +6,7 @@ from torch import Tensor from torch.nn.utils.rnn import pad_sequence -from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ( - ModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ModelVocabulary class Dataset(tud.Dataset): @@ -34,9 +32,7 @@ def __init__(self, smiles_list, model_vocabulary: ModelVocabulary): # TODO log theses cases def __getitem__(self, i): - return torch.tensor( - self._encoded_list[i], dtype=torch.long - ) # pylint: disable=E1102 + return torch.tensor(self._encoded_list[i], dtype=torch.long) # pylint: disable=E1102 def __len__(self): return len(self._encoded_list) @@ -52,9 +48,8 @@ def _pad_batch(encoded_seqs: List) -> Tuple[Tensor, Tensor]: :param encoded_seqs: A list of encoded sequences. :return: A tensor with the sequences correctly padded. 
""" - seq_lengths = torch.tensor( - [len(seq) for seq in encoded_seqs], dtype=torch.int64 - ) - if torch.cuda.is_available(): - return pad_sequence(encoded_seqs, batch_first=True).cuda(), seq_lengths - return pad_sequence(encoded_seqs, batch_first=True), seq_lengths + seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64) + return pad_sequence(encoded_seqs, batch_first=True).cuda() if torch.cuda.is_available() else pad_sequence(encoded_seqs, batch_first=True), seq_lengths + + + diff --git a/reinvent_models/link_invent/dataset/paired_dataset.py b/reinvent_models/link_invent/dataset/paired_dataset.py index 827aec8..4b62435 100644 --- a/reinvent_models/link_invent/dataset/paired_dataset.py +++ b/reinvent_models/link_invent/dataset/paired_dataset.py @@ -5,17 +5,13 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils import data as tud -from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import ( - PairedModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import PairedModelVocabulary class PairedDataset(tud.Dataset): """Dataset that takes a list of (input, output) pairs.""" - def __init__( - self, input_target_smi_list: List[List[str]], vocabulary: PairedModelVocabulary - ): + def __init__(self, input_target_smi_list: List[List[str]], vocabulary: PairedModelVocabulary): self.vocabulary = vocabulary self._encoded_list = [] @@ -30,10 +26,8 @@ def __init__( def __getitem__(self, i): en_input, en_output = self._encoded_list[i] - return ( - torch.tensor(en_input, dtype=torch.long), - torch.tensor(en_output, dtype=torch.long), - ) # pylint: disable=E1102 + return (torch.tensor(en_input, dtype=torch.long), + torch.tensor(en_output, dtype=torch.long)) # pylint: disable=E1102 def __len__(self): return len(self._encoded_list) @@ -55,9 +49,5 @@ def _pad_batch(encoded_seqs: List) -> Tuple[Tensor, Tensor]: :param encoded_seqs: A list of encoded sequences. :return: A tensor with the sequences correctly padded. 
""" - seq_lengths = torch.tensor( - [len(seq) for seq in encoded_seqs], dtype=torch.int64 - ) - if torch.cuda.is_available(): - return pad_sequence(encoded_seqs, batch_first=True).cuda(), seq_lengths - return pad_sequence(encoded_seqs, batch_first=True), seq_lengths + seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64) + return pad_sequence(encoded_seqs, batch_first=True).cuda() if torch.cuda.is_available() else pad_sequence(encoded_seqs, batch_first=True), seq_lengths diff --git a/reinvent_models/link_invent/dto/__init__.py b/reinvent_models/link_invent/dto/__init__.py index 2354fa5..d96e08b 100644 --- a/reinvent_models/link_invent/dto/__init__.py +++ b/reinvent_models/link_invent/dto/__init__.py @@ -1,4 +1,2 @@ -from reinvent_models.link_invent.dto.link_invent_model_parameters_dto import ( - LinkInventModelParameterDTO, -) +from reinvent_models.link_invent.dto.link_invent_model_parameters_dto import LinkInventModelParameterDTO from reinvent_models.link_invent.dto.sampled_sequence_dto import SampledSequencesDTO diff --git a/reinvent_models/link_invent/dto/link_invent_model_parameters_dto.py b/reinvent_models/link_invent/dto/link_invent_model_parameters_dto.py index 8db4ce0..cc9ce1d 100644 --- a/reinvent_models/link_invent/dto/link_invent_model_parameters_dto.py +++ b/reinvent_models/link_invent/dto/link_invent_model_parameters_dto.py @@ -1,8 +1,6 @@ from dataclasses import dataclass -from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import ( - PairedModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import PairedModelVocabulary @dataclass @@ -10,4 +8,4 @@ class LinkInventModelParameterDTO: vocabulary: PairedModelVocabulary max_sequence_length: int network_parameter: dict - network_state: dict + network_state: dict \ No newline at end of file diff --git a/reinvent_models/link_invent/dto/sampled_sequence_dto.py b/reinvent_models/link_invent/dto/sampled_sequence_dto.py index 7bf9e39..8861dd7 100644 --- a/reinvent_models/link_invent/dto/sampled_sequence_dto.py +++ b/reinvent_models/link_invent/dto/sampled_sequence_dto.py @@ -5,4 +5,4 @@ class SampledSequencesDTO: input: str output: str - nll: float + nll: float \ No newline at end of file diff --git a/reinvent_models/link_invent/link_invent_model.py b/reinvent_models/link_invent/link_invent_model.py index e13a7db..28059aa 100644 --- a/reinvent_models/link_invent/link_invent_model.py +++ b/reinvent_models/link_invent/link_invent_model.py @@ -1,4 +1,3 @@ -from dataclasses import asdict from typing import List, Union, Any import torch @@ -7,23 +6,15 @@ from reinvent_models.link_invent.dto import LinkInventModelParameterDTO from reinvent_models.link_invent.dto import SampledSequencesDTO -from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import ( - PairedModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.paired_model_vocabulary import PairedModelVocabulary from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum from reinvent_models.model_factory.generative_model_base import GenerativeModelBase from reinvent_models.link_invent.networks import EncoderDecoder class LinkInventModel(GenerativeModelBase): - def __init__( - self, - vocabulary: PairedModelVocabulary, - network: EncoderDecoder, - max_sequence_length: int = 256, - no_cuda: bool = False, - mode: str = ModelModeEnum().TRAINING, - ): + def __init__(self, vocabulary: PairedModelVocabulary, network: EncoderDecoder, + 
+                 max_sequence_length: int = 256, no_cuda: bool = False, mode: str = ModelModeEnum().TRAINING):
         self.vocabulary = vocabulary
         self.network = network
         self.max_sequence_length = max_sequence_length
@@ -44,24 +35,23 @@ def set_mode(self, mode: str):
             raise ValueError(f"Invalid model mode '{mode}")
 
     @classmethod
-    def load_from_file(
-        cls, path_to_file, mode: str = ModelModeEnum().TRAINING
-    ) -> Union[Any, GenerativeModelBase]:
+    def load_from_file(cls, path_to_file, mode: str = ModelModeEnum().TRAINING) -> Union[Any, GenerativeModelBase]:
         """
         Loads a model from a single file
         :param path_to_file: Path to the saved model
         :param mode: Mode in which the model should be initialized
         :return: An instance of the network
         """
-        data = from_dict(LinkInventModelParameterDTO, torch.load(path_to_file))
+        if torch.cuda.is_available():
+            model_ = torch.load(path_to_file)
+        else:
+            model_ = torch.load(path_to_file, map_location=torch.device('cpu'))
+
+        data = from_dict(LinkInventModelParameterDTO, model_)
         network = EncoderDecoder(**data.network_parameter)
         network.load_state_dict(data.network_state)
 
-        model = LinkInventModel(
-            vocabulary=data.vocabulary,
-            network=network,
-            max_sequence_length=data.max_sequence_length,
-            mode=mode,
-        )
+        model = LinkInventModel(vocabulary=data.vocabulary, network=network,
+                                max_sequence_length=data.max_sequence_length, mode=mode)
         return model
 
     def save_to_file(self, path_to_file):
@@ -69,17 +59,12 @@ def save_to_file(self, path_to_file):
         Saves the model to a file.
         :param path_to_file: Path to the file which the model will be saved to.
         """
-        data = LinkInventModelParameterDTO(
-            vocabulary=self.vocabulary,
-            max_sequence_length=self.max_sequence_length,
-            network_parameter=self.network.get_params(),
-            network_state=self.network.state_dict(),
-        )
-        torch.save(asdict(data), path_to_file)
-
-    def likelihood(
-        self, warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths
-    ):
+        data = LinkInventModelParameterDTO(vocabulary=self.vocabulary, max_sequence_length=self.max_sequence_length,
+                                           network_parameter=self.network.get_params(),
+                                           network_state=self.network.state_dict())
+        torch.save(data.__dict__, path_to_file)
+
+    def likelihood(self, warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths):
         """
         Retrieves the likelihood of warheads and their respective linker.
         :param warheads_seqs: (batch, seq) A batch of padded scaffold sequences.
@@ -90,9 +75,8 @@
         """
         # NOTE: the decoration_seq_lengths have a - 1 to prevent the end token to be forward-passed.
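Review note: the CPU branch added to load_from_file above is the usual device-safe loading pattern; a checkpoint pickled on a GPU host fails to load on a CPU-only machine unless map_location redirects the tensors. The two branches can also be collapsed into one call (sketch; the file name is illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("link_invent.ckpt", map_location=device)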
- logits = self.network( - warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths - 1 - ) # (batch, seq - 1, voc) + logits = self.network(warheads_seqs, warheads_seq_lengths, linker_seqs, + linker_seq_lengths - 1) # (batch, seq - 1, voc) log_probs = logits.log_softmax(dim=2).transpose(1, 2) # (batch, voc, seq - 1) return self._nll_loss(log_probs, linker_seqs[:, 1:]).sum(dim=1) # (batch) @@ -106,51 +90,37 @@ def sample(self, inputs, input_seq_lengths) -> List[SampledSequencesDTO]: """ batch_size = inputs.size(0) - input_vector = torch.full( - (batch_size, 1), self.vocabulary.target.vocabulary["^"], dtype=torch.long - ) # (batch, 1) + input_vector = torch.full((batch_size, 1), self.vocabulary.target.vocabulary["^"], + dtype=torch.long) # (batch, 1) seq_lengths = torch.ones(batch_size) # (batch) - encoder_padded_seqs, hidden_states = self.network.forward_encoder( - inputs, input_seq_lengths - ) + encoder_padded_seqs, hidden_states = self.network.forward_encoder(inputs, input_seq_lengths) nlls = torch.zeros(batch_size) not_finished = torch.ones(batch_size, 1, dtype=torch.long) + if torch.cuda.is_available(): - input_vector = input_vector.cuda() - nlls = nlls.cuda() - input_vector = input_vector.cuda() + input_vector = torch.full((batch_size, 1), self.vocabulary.target.vocabulary["^"], + dtype=torch.long).cuda() # (batch, 1) + nlls = torch.zeros(batch_size).cuda() + not_finished = torch.ones(batch_size, 1, dtype=torch.long).cuda() + sequences = [] for _ in range(self.max_sequence_length - 1): logits, hidden_states, _ = self.network.forward_decoder( - input_vector, seq_lengths, encoder_padded_seqs, hidden_states - ) # (batch, 1, voc) - probs = logits.softmax(dim=2).squeeze() # (batch, voc) - log_probs = logits.log_softmax(dim=2).squeeze() # (batch, voc) + input_vector, seq_lengths, encoder_padded_seqs, hidden_states) # (batch, 1, voc) + probs = logits.softmax(dim=2).squeeze(dim=1) # (batch, voc) + log_probs = logits.log_softmax(dim=2).squeeze(dim=1) # (batch, voc) input_vector = torch.multinomial(probs, 1) * not_finished # (batch, 1) sequences.append(input_vector) - nlls += self._nll_loss(log_probs, input_vector.squeeze()) - not_finished = (input_vector > 1).type( - torch.long - ) # 0 is padding, 1 is end token + nlls += self._nll_loss(log_probs, input_vector.squeeze(dim=1)) + not_finished = (input_vector > 1).type(torch.long) # 0 is padding, 1 is end token if not_finished.sum() == 0: break - linker_smiles_list = [ - self.vocabulary.target.decode(seq) - for seq in torch.cat(sequences, 1).data.cpu().numpy() - ] - warheads_smiles_list = [ - self.vocabulary.input.decode(seq) for seq in inputs.data.cpu().numpy() - ] - - result = [ - SampledSequencesDTO(warheads, linker, nll) - for warheads, linker, nll in zip( - warheads_smiles_list, - linker_smiles_list, - nlls.data.cpu().numpy().tolist(), - ) - ] + linker_smiles_list = [self.vocabulary.target.decode(seq) for seq in torch.cat(sequences, 1).data.cpu().numpy()] + warheads_smiles_list = [self.vocabulary.input.decode(seq) for seq in inputs.data.cpu().numpy()] + + result = [SampledSequencesDTO(warheads, linker, nll) for warheads, linker, nll in + zip(warheads_smiles_list, linker_smiles_list, nlls.data.cpu().numpy().tolist())] return result def get_network_parameters(self): diff --git a/reinvent_models/link_invent/model_vocabulary/__init__.py b/reinvent_models/link_invent/model_vocabulary/__init__.py index 23d4e97..b00145d 100644 --- a/reinvent_models/link_invent/model_vocabulary/__init__.py +++ 
b/reinvent_models/link_invent/model_vocabulary/__init__.py @@ -1,3 +1,3 @@ -from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ( - ModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ModelVocabulary + + diff --git a/reinvent_models/link_invent/model_vocabulary/model_vocabulary.py b/reinvent_models/link_invent/model_vocabulary/model_vocabulary.py index d87ff1b..4042766 100644 --- a/reinvent_models/link_invent/model_vocabulary/model_vocabulary.py +++ b/reinvent_models/link_invent/model_vocabulary/model_vocabulary.py @@ -1,10 +1,6 @@ from typing import List -from reinvent_models.link_invent.model_vocabulary.vocabulary import ( - Vocabulary, - SMILESTokenizer, - create_vocabulary, -) +from reinvent_models.link_invent.model_vocabulary.vocabulary import Vocabulary, SMILESTokenizer, create_vocabulary class ModelVocabulary: diff --git a/reinvent_models/link_invent/model_vocabulary/paired_model_vocabulary.py b/reinvent_models/link_invent/model_vocabulary/paired_model_vocabulary.py index a30c97b..3323a24 100644 --- a/reinvent_models/link_invent/model_vocabulary/paired_model_vocabulary.py +++ b/reinvent_models/link_invent/model_vocabulary/paired_model_vocabulary.py @@ -1,22 +1,12 @@ from typing import List -from reinvent_models.link_invent.model_vocabulary.vocabulary import ( - SMILESTokenizer, - Vocabulary, -) -from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ( - ModelVocabulary, -) +from reinvent_models.link_invent.model_vocabulary.vocabulary import SMILESTokenizer, Vocabulary +from reinvent_models.link_invent.model_vocabulary.model_vocabulary import ModelVocabulary class PairedModelVocabulary: - def __init__( - self, - input_vocabulary: Vocabulary, - input_tokenizer: SMILESTokenizer, - output_vocabulary: Vocabulary, - output_tokenizer: SMILESTokenizer, - ): + def __init__(self, input_vocabulary: Vocabulary, input_tokenizer: SMILESTokenizer, + output_vocabulary: Vocabulary, output_tokenizer: SMILESTokenizer): self.input = ModelVocabulary(input_vocabulary, input_tokenizer) self.target = ModelVocabulary(output_vocabulary, output_tokenizer) @@ -34,9 +24,5 @@ def from_lists(cls, input_smiles_list: List[str], target_smiles_list: List[str]) input_vocabulary = ModelVocabulary.from_list(input_smiles_list) target_vocabulary = ModelVocabulary.from_list(target_smiles_list) - return PairedModelVocabulary( - input_vocabulary.vocabulary, - input_vocabulary.tokenizer, - target_vocabulary.vocabulary, - input_vocabulary.tokenizer, - ) + return PairedModelVocabulary(input_vocabulary.vocabulary, input_vocabulary.tokenizer, + target_vocabulary.vocabulary, input_vocabulary.tokenizer) diff --git a/reinvent_models/link_invent/model_vocabulary/vocabulary.py b/reinvent_models/link_invent/model_vocabulary/vocabulary.py index 69f1bd9..0758a3b 100644 --- a/reinvent_models/link_invent/model_vocabulary/vocabulary.py +++ b/reinvent_models/link_invent/model_vocabulary/vocabulary.py @@ -137,7 +137,7 @@ class SMILESTokenizer: REGEXPS = { "brackets": re.compile(r"(\[[^\]]*\])"), "2_ring_nums": re.compile(r"(%\d{2})"), - "brcl": re.compile(r"(Br|Cl)"), + "brcl": re.compile(r"(Br|Cl)") } REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"] @@ -148,7 +148,6 @@ def tokenize(self, smiles, with_begin_and_end=True): :param with_begin_and_end: Appends a begin token and prepends an end token. :return : A list with the tokenized version. 
""" - def split_by(smiles, regexps): if not regexps: return list(smiles) diff --git a/reinvent_models/link_invent/networks/attention_layer.py b/reinvent_models/link_invent/networks/attention_layer.py new file mode 100644 index 0000000..1c07a6c --- /dev/null +++ b/reinvent_models/link_invent/networks/attention_layer.py @@ -0,0 +1,37 @@ +import math + +import torch +from torch import nn as tnn + + +class AttentionLayer(tnn.Module): + + def __init__(self, num_dimensions: int): + super(AttentionLayer, self).__init__() + + self.num_dimensions = num_dimensions + + self._attention_linear = tnn.Sequential( + tnn.Linear(self.num_dimensions*2, self.num_dimensions), + tnn.Tanh() + ) + + def forward(self, padded_seqs: torch.Tensor, encoder_padded_seqs: torch.Tensor, decoder_mask: torch.Tensor) \ + -> (torch.Tensor, torch.Tensor): # pylint: disable=arguments-differ + """ + Performs the forward pass. + :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim). + :param encoder_padded_seqs: A tensor with the encoded input sequences (batch, seq_e, dim). + :param decoder_mask: A tensor that represents the encoded input mask. + :return : Two tensors: one with the modified logits and another with the attention weights. + """ + # scaled dot-product + # (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) => (batch, seq_d, seq_e*) + attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\ + .sum(dim=3).div(math.sqrt(self.num_dimensions))\ + .softmax(dim=2) + # (batch, seq_d, seq_e*)@(batch, seq_e, dim) => (batch, seq_d, dim) + attention_context = attention_weights.bmm(encoder_padded_seqs) + + return (self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))*decoder_mask, + attention_weights) diff --git a/reinvent_models/link_invent/networks/decoder.py b/reinvent_models/link_invent/networks/decoder.py new file mode 100644 index 0000000..0e7a7cc --- /dev/null +++ b/reinvent_models/link_invent/networks/decoder.py @@ -0,0 +1,71 @@ +from typing import Tuple + +import torch +from torch import nn as tnn +from torch.nn.utils import rnn as tnnur + +from reinvent_models.link_invent.networks.attention_layer import AttentionLayer +from reinvent_models.model_factory.enums.model_parameter_enum import ModelParametersEnum + + +class Decoder(tnn.Module): + """ + Simple RNN decoder. + """ + + def __init__(self, num_layers: int, num_dimensions: int, vocabulary_size: int, dropout: float): + super(Decoder, self).__init__() + + self.num_layers = num_layers + self.num_dimensions = num_dimensions + self.vocabulary_size = vocabulary_size + self.dropout = dropout + + self._embedding = tnn.Sequential( + tnn.Embedding(self.vocabulary_size, self.num_dimensions), + tnn.Dropout(dropout) + ) + self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers, + batch_first=True, dropout=self.dropout, bidirectional=False) + + self._attention = AttentionLayer(self.num_dimensions) + + self._linear = tnn.Linear(self.num_dimensions, self.vocabulary_size) # just to redimension + + def forward(self, padded_seqs: torch.Tensor, seq_lengths: torch.Tensor, + encoder_padded_seqs: torch.Tensor, hidden_states: Tuple[torch.Tensor]) \ + -> (torch.Tensor, Tuple[torch.Tensor], torch.Tensor): # pylint: disable=arguments-differ + """ + Performs the forward pass. + :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim). + :param seq_lengths: A list with the length of each output sequence. 
+ :param encoder_padded_seqs: A tensor with the encoded input sequences (batch, seq_e, dim). + :param hidden_states: The hidden states from the encoder. + :return : Three tensors: The output logits, the hidden states of the decoder and the attention weights. + """ + # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 + seq_lengths = seq_lengths.cpu() + + padded_encoded_seqs = self._embedding(padded_seqs) + packed_encoded_seqs = tnnur.pack_padded_sequence( + padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False) + packed_encoded_seqs, hidden_states = self._rnn(packed_encoded_seqs, hidden_states) + padded_encoded_seqs, _ = tnnur.pad_packed_sequence(packed_encoded_seqs, batch_first=True) # (batch, seq, dim) + + mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float) + attn_padded_encoded_seqs, attention_weights = self._attention(padded_encoded_seqs, encoder_padded_seqs, mask) + logits = self._linear(attn_padded_encoded_seqs)*mask # (batch, seq, voc_size) + return logits, hidden_states, attention_weights + + def get_params(self) -> dict: + parameter_enum = ModelParametersEnum + """ + Obtains the params for the network. + :return : A dict with the params. + """ + return { + parameter_enum.NUMBER_OF_LAYERS: self.num_layers, + parameter_enum.NUMBER_OF_DIMENSIONS: self.num_dimensions, + parameter_enum.VOCABULARY_SIZE: self.vocabulary_size, + parameter_enum.DROPOUT: self.dropout + } \ No newline at end of file diff --git a/reinvent_models/link_invent/networks/encoder.py b/reinvent_models/link_invent/networks/encoder.py new file mode 100644 index 0000000..d924146 --- /dev/null +++ b/reinvent_models/link_invent/networks/encoder.py @@ -0,0 +1,70 @@ +import torch +from torch import nn as tnn +from torch.nn.utils import rnn as tnnur + +from reinvent_models.model_factory.enums.model_parameter_enum import ModelParametersEnum + + +class Encoder(tnn.Module): + """ + Simple bidirectional RNN encoder implementation. + """ + + def __init__(self, num_layers: int, num_dimensions: int, vocabulary_size: int, dropout: float): + super(Encoder, self).__init__() + + self.num_layers = num_layers + self.num_dimensions = num_dimensions + self.vocabulary_size = vocabulary_size + self.dropout = dropout + + self._embedding = tnn.Sequential( + tnn.Embedding(self.vocabulary_size, self.num_dimensions), + tnn.Dropout(dropout) + ) + self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers, + batch_first=True, dropout=self.dropout, bidirectional=True) + + def forward(self, padded_seqs: torch.Tensor, seq_lengths: torch.Tensor) \ + -> (torch.Tensor, (torch.Tensor, torch.Tensor)): # pylint: disable=arguments-differ + """ + Performs the forward pass. + :param padded_seqs: A tensor with the sequences (batch, seq). + :param seq_lengths: The lengths of the sequences (for packed sequences). + :return : A tensor with all the output values for each step and the two hidden states. 
+ """ + batch_size, max_seq_size = padded_seqs.size() + hidden_state = self._initialize_hidden_state(batch_size) + + padded_seqs = self._embedding(padded_seqs) + hs_h, hs_c = (hidden_state, hidden_state.clone().detach()) + + #FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 + seq_lengths = seq_lengths.cpu() + + packed_seqs = tnnur.pack_padded_sequence(padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False) + packed_seqs, (hs_h, hs_c) = self._rnn(packed_seqs, (hs_h, hs_c)) + padded_seqs, _ = tnnur.pad_packed_sequence(packed_seqs, batch_first=True) + + # sum up bidirectional layers and collapse + hs_h = hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions).sum(dim=1) # (layers, batch, dim) + hs_c = hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions).sum(dim=1) # (layers, batch, dim) + padded_seqs = padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions).sum(dim=2) # (batch, seq, dim) + + return padded_seqs, (hs_h, hs_c) + + def _initialize_hidden_state(self, batch_size: int) -> torch.Tensor: + return torch.zeros(self.num_layers*2, batch_size, self.num_dimensions).cuda() if torch.cuda.is_available() else torch.zeros(self.num_layers*2, batch_size, self.num_dimensions) + + def get_params(self) -> dict: + parameter_enums = ModelParametersEnum + """ + Obtains the params for the network. + :return : A dict with the params. + """ + return { + parameter_enums.NUMBER_OF_LAYERS: self.num_layers, + parameter_enums.NUMBER_OF_DIMENSIONS: self.num_dimensions, + parameter_enums.VOCABULARY_SIZE: self.vocabulary_size, + parameter_enums.DROPOUT: self.dropout + } \ No newline at end of file diff --git a/reinvent_models/link_invent/networks/encoder_decoder.py b/reinvent_models/link_invent/networks/encoder_decoder.py index c2f9943..f71b2aa 100644 --- a/reinvent_models/link_invent/networks/encoder_decoder.py +++ b/reinvent_models/link_invent/networks/encoder_decoder.py @@ -1,218 +1,12 @@ """ Implementation of a network using an Encoder-Decoder architecture. """ -import math -import torch import torch.nn as tnn -import torch.nn.utils.rnn as tnnur +from torch import Tensor -from reinvent_models.model_factory.enums.model_parameter_enum import ModelParametersEnum - - -class Encoder(tnn.Module): - """ - Simple bidirectional RNN encoder implementation. - """ - - def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout): - super(Encoder, self).__init__() - - self.num_layers = num_layers - self.num_dimensions = num_dimensions - self.vocabulary_size = vocabulary_size - self.dropout = dropout - - self._embedding = tnn.Sequential( - tnn.Embedding(self.vocabulary_size, self.num_dimensions), - tnn.Dropout(dropout), - ) - self._rnn = tnn.LSTM( - self.num_dimensions, - self.num_dimensions, - self.num_layers, - batch_first=True, - dropout=self.dropout, - bidirectional=True, - ) - - def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ - # FIXME: This fails with a batch of 1 because squeezing looses a dimension with size 1 - """ - Performs the forward pass. - :param padded_seqs: A tensor with the sequences (batch, seq). - :param seq_lengths: The lengths of the sequences (for packed sequences). - :return : A tensor with all the output values for each step and the two hidden states. 
- """ - batch_size = padded_seqs.size(0) - max_seq_size = padded_seqs.size(1) - hidden_state = self._initialize_hidden_state(batch_size) - - padded_seqs = self._embedding(padded_seqs) - hs_h, hs_c = (hidden_state, hidden_state.clone().detach()) - - # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 - seq_lengths = seq_lengths.cpu() - - packed_seqs = tnnur.pack_padded_sequence( - padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False - ) - packed_seqs, (hs_h, hs_c) = self._rnn(packed_seqs, (hs_h, hs_c)) - padded_seqs, _ = tnnur.pad_packed_sequence(packed_seqs, batch_first=True) - - # sum up bidirectional layers and collapse - hs_h = ( - hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions) - .sum(dim=1) - .squeeze() - ) # (layers, batch, dim) - hs_c = ( - hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions) - .sum(dim=1) - .squeeze() - ) # (layers, batch, dim) - padded_seqs = ( - padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions) - .sum(dim=2) - .squeeze() - ) # (batch, seq, dim) - - return padded_seqs, (hs_h, hs_c) - - def _initialize_hidden_state(self, batch_size): - if torch.cuda.is_available(): - return torch.zeros( - self.num_layers * 2, batch_size, self.num_dimensions - ).cuda() - return torch.zeros(self.num_layers * 2, batch_size, self.num_dimensions) - - def get_params(self): - parameter_enums = ModelParametersEnum - """ - Obtains the params for the network. - :return : A dict with the params. - """ - return { - parameter_enums.NUMBER_OF_LAYERS: self.num_layers, - parameter_enums.NUMBER_OF_DIMENSIONS: self.num_dimensions, - parameter_enums.VOCABULARY_SIZE: self.vocabulary_size, - parameter_enums.DROPOUT: self.dropout, - } - - -class AttentionLayer(tnn.Module): - def __init__(self, num_dimensions): - super(AttentionLayer, self).__init__() - - self.num_dimensions = num_dimensions - - self._attention_linear = tnn.Sequential( - tnn.Linear(self.num_dimensions * 2, self.num_dimensions), tnn.Tanh() - ) - - def forward( - self, padded_seqs, encoder_padded_seqs, decoder_mask - ): # pylint: disable=arguments-differ - """ - Performs the forward pass. - :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim). - :param encoder_padded_seqs: A tensor with the encoded input sequences (batch, seq_e, dim). - :param decoder_mask: A tensor that represents the encoded input mask. - :return : Two tensors: one with the modified logits and another with the attention weights. - """ - # scaled dot-product - # (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) => (batch, seq_d, seq_e*) - attention_weights = ( - (padded_seqs.unsqueeze(dim=2) * encoder_padded_seqs.unsqueeze(dim=1)) - .sum(dim=3) - .div(math.sqrt(self.num_dimensions)) - .softmax(dim=2) - ) - # (batch, seq_d, seq_e*)@(batch, seq_e, dim) => (batch, seq_d, dim) - attention_context = attention_weights.bmm(encoder_padded_seqs) - return ( - self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2)) - * decoder_mask, - attention_weights, - ) - - -class Decoder(tnn.Module): - """ - Simple RNN decoder. 
- """ - - def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout): - super(Decoder, self).__init__() - - self.num_layers = num_layers - self.num_dimensions = num_dimensions - self.vocabulary_size = vocabulary_size - self.dropout = dropout - - self._embedding = tnn.Sequential( - tnn.Embedding(self.vocabulary_size, self.num_dimensions), - tnn.Dropout(dropout), - ) - self._rnn = tnn.LSTM( - self.num_dimensions, - self.num_dimensions, - self.num_layers, - batch_first=True, - dropout=self.dropout, - bidirectional=False, - ) - - self._attention = AttentionLayer(self.num_dimensions) - - self._linear = tnn.Linear( - self.num_dimensions, self.vocabulary_size - ) # just to redimension - - def forward( - self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states - ): # pylint: disable=arguments-differ - """ - Performs the forward pass. - :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim). - :param seq_lengths: A list with the length of each output sequence. - :param encoder_padded_seqs: A tensor with the encoded input sequences (batch, seq_e, dim). - :param hidden_states: The hidden states from the encoder. - :return : Three tensors: The output logits, the hidden states of the decoder and the attention weights. - """ - # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7 - seq_lengths = seq_lengths.cpu() - - padded_encoded_seqs = self._embedding(padded_seqs) - packed_encoded_seqs = tnnur.pack_padded_sequence( - padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False - ) - packed_encoded_seqs, hidden_states = self._rnn( - packed_encoded_seqs, hidden_states - ) - padded_encoded_seqs, _ = tnnur.pad_packed_sequence( - packed_encoded_seqs, batch_first=True - ) # (batch, seq, dim) - - mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float) - attn_padded_encoded_seqs, attention_weights = self._attention( - padded_encoded_seqs, encoder_padded_seqs, mask - ) - logits = self._linear(attn_padded_encoded_seqs) * mask # (batch, seq, voc_size) - return logits, hidden_states, attention_weights - - def get_params(self): - parameter_enum = ModelParametersEnum - """ - Obtains the params for the network. - :return : A dict with the params. - """ - return { - parameter_enum.NUMBER_OF_LAYERS: self.num_layers, - parameter_enum.NUMBER_OF_DIMENSIONS: self.num_dimensions, - parameter_enum.VOCABULARY_SIZE: self.vocabulary_size, - parameter_enum.DROPOUT: self.dropout, - } +from reinvent_models.link_invent.networks.decoder import Decoder +from reinvent_models.link_invent.networks.encoder import Encoder class EncoderDecoder(tnn.Module): @@ -220,15 +14,14 @@ class EncoderDecoder(tnn.Module): An encoder-decoder that combines input with generated targets. """ - def __init__(self, encoder_params, decoder_params): + def __init__(self, encoder_params: dict, decoder_params: dict): super(EncoderDecoder, self).__init__() self._encoder = Encoder(**encoder_params) self._decoder = Decoder(**decoder_params) - def forward( - self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths - ): # pylint: disable=arguments-differ + def forward(self, encoder_seqs: Tensor, encoder_seq_lengths: Tensor, decoder_seqs: Tensor, + decoder_seq_lengths: Tensor): """ Performs the forward pass. :param encoder_seqs: A tensor with the output sequences (batch, seq_d, dim). @@ -237,15 +30,11 @@ def forward( :param decoder_seq_lengths: The lengths of the decoder sequences. 
:return : The output logits as a tensor (batch, seq_d, dim). """ - encoder_padded_seqs, hidden_states = self.forward_encoder( - encoder_seqs, encoder_seq_lengths - ) - logits, _, _ = self.forward_decoder( - decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states - ) + encoder_padded_seqs, hidden_states = self.forward_encoder(encoder_seqs, encoder_seq_lengths) + logits, _, _ = self.forward_decoder(decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states) return logits - def forward_encoder(self, padded_seqs, seq_lengths): + def forward_encoder(self, padded_seqs: Tensor, seq_lengths: Tensor): """ Does a forward pass only of the encoder. :param padded_seqs: The data to feed the encoder. @@ -254,9 +43,8 @@ def forward_encoder(self, padded_seqs, seq_lengths): """ return self._encoder(padded_seqs, seq_lengths) - def forward_decoder( - self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states - ): + def forward_decoder(self, padded_seqs: Tensor, seq_lengths: Tensor, encoder_padded_seqs: Tensor, + hidden_states: Tensor): """ Does a forward pass only of the decoder. :param hidden_states: The hidden states from the encoder. @@ -264,9 +52,7 @@ def forward_decoder( :param seq_lengths: The length of each sequence in the batch. :return : Returns the logits and the hidden state for each element of the sequence passed. """ - return self._decoder( - padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states - ) + return self._decoder(padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states) def get_params(self): """ @@ -275,5 +61,5 @@ def get_params(self): """ return { "encoder_params": self._encoder.get_params(), - "decoder_params": self._decoder.get_params(), + "decoder_params": self._decoder.get_params() } diff --git a/reinvent_models/model_factory/enums/model_mode_enum.py b/reinvent_models/model_factory/enums/model_mode_enum.py index e53888b..b4f3f3f 100644 --- a/reinvent_models/model_factory/enums/model_mode_enum.py +++ b/reinvent_models/model_factory/enums/model_mode_enum.py @@ -1,13 +1,7 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) class ModelModeEnum: INFERENCE = "inference" TRAINING = "training" - - # try to find the internal value and return - def __getattr__(self, name): - if name in self: - return name - raise AttributeError - - # prohibit any attempt to set any values - def __setattr__(self, key, value): - raise ValueError("No changes allowed.") diff --git a/reinvent_models/model_factory/enums/model_parameter_enum.py b/reinvent_models/model_factory/enums/model_parameter_enum.py index 84046b0..48b01e3 100644 --- a/reinvent_models/model_factory/enums/model_parameter_enum.py +++ b/reinvent_models/model_factory/enums/model_parameter_enum.py @@ -1,15 +1,9 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) class ModelParametersEnum: NUMBER_OF_LAYERS = "num_layers" NUMBER_OF_DIMENSIONS = "num_dimensions" VOCABULARY_SIZE = "vocabulary_size" DROPOUT = "dropout" - - # try to find the internal value and return - def __getattr__(self, name): - if name in self: - return name - raise AttributeError - - # prohibit any attempt to set any values - def __setattr__(self, key, value): - raise ValueError("No changes allowed.") diff --git a/reinvent_models/model_factory/enums/model_type_enum.py b/reinvent_models/model_factory/enums/model_type_enum.py index ff82bb3..e7987cf 100644 --- a/reinvent_models/model_factory/enums/model_type_enum.py +++ b/reinvent_models/model_factory/enums/model_type_enum.py @@ -1,15 +1,9 @@ +from dataclasses 
diff --git a/reinvent_models/model_factory/generative_model.py b/reinvent_models/model_factory/generative_model.py
index 9e42086..d9ad886 100644
--- a/reinvent_models/model_factory/generative_model.py
+++ b/reinvent_models/model_factory/generative_model.py
@@ -1,6 +1,4 @@
-from reinvent_models.model_factory.configurations.model_configuration import (
-    ModelConfiguration,
-)
+from reinvent_models.model_factory.configurations.model_configuration import ModelConfiguration
 from reinvent_models.model_factory.enums.model_type_enum import ModelTypeEnum
 from reinvent_models.model_factory.generative_model_base import GenerativeModelBase
 from reinvent_models.model_factory.lib_invent_adapter import LibInventAdapter
@@ -14,19 +12,12 @@ def __new__(cls, configuration: ModelConfiguration) -> GenerativeModelBase:
         model_type_enum = ModelTypeEnum()
 
         if cls._configuration.model_type == model_type_enum.DEFAULT:
-            model = ReinventCoreAdapter(
-                cls._configuration.model_file_path, mode=cls._configuration.model_mode
-            )
+            model = ReinventCoreAdapter(cls._configuration.model_file_path, mode=cls._configuration.model_mode)
         elif cls._configuration.model_type == model_type_enum.LIB_INVENT:
-            model = LibInventAdapter(
-                cls._configuration.model_file_path, mode=cls._configuration.model_mode
-            )
+            model = LibInventAdapter(cls._configuration.model_file_path, mode=cls._configuration.model_mode)
         elif cls._configuration.model_type == model_type_enum.LINK_INVENT:
-            model = LinkInventAdapter(
-                cls._configuration.model_file_path, mode=cls._configuration.model_mode
-            )
+            model = LinkInventAdapter(cls._configuration.model_file_path, mode=cls._configuration.model_mode)
         else:
-            raise ValueError(
-                f"Invalid model_type provided: '{cls._configuration.model_type}"
-            )
+            raise ValueError(f"Invalid model_type provided: '{cls._configuration.model_type}'")
 
         return model
+
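A usage sketch for the factory; `ModelConfiguration`'s exact signature is not shown in this diff, so the keyword layout and the checkpoint path below are assumptions:

```python
from reinvent_models.model_factory.configurations.model_configuration import ModelConfiguration
from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum
from reinvent_models.model_factory.enums.model_type_enum import ModelTypeEnum
from reinvent_models.model_factory.generative_model import GenerativeModel

model_types = ModelTypeEnum()
model_modes = ModelModeEnum()

# "prior.ckpt" is a placeholder checkpoint path; field names follow the
# attribute accesses in GenerativeModel.__new__ above.
configuration = ModelConfiguration(
    model_type=model_types.LIB_INVENT,
    model_mode=model_modes.INFERENCE,
    model_file_path="prior.ckpt",
)
adapter = GenerativeModel(configuration)  # -> LibInventAdapter
```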
diff --git a/reinvent_models/model_factory/generative_model_base.py b/reinvent_models/model_factory/generative_model_base.py
index 2b2833f..85a74a5 100644
--- a/reinvent_models/model_factory/generative_model_base.py
+++ b/reinvent_models/model_factory/generative_model_base.py
@@ -2,6 +2,7 @@
 
 class GenerativeModelBase(ABC):
+    @abstractmethod
     def save_to_file(self, path_to_file: str):
         raise NotImplementedError("save_to_file method is not implemented")
diff --git a/reinvent_models/model_factory/lib_invent_adapter.py b/reinvent_models/model_factory/lib_invent_adapter.py
index 55659a9..334c742 100644
--- a/reinvent_models/model_factory/lib_invent_adapter.py
+++ b/reinvent_models/model_factory/lib_invent_adapter.py
@@ -1,9 +1,9 @@
 from reinvent_models.lib_invent.models.model import DecoratorModel
-from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum
 from reinvent_models.model_factory.generative_model_base import GenerativeModelBase
 
 
 class LibInventAdapter(GenerativeModelBase):
+
     def __init__(self, path_to_file: str, mode: str):
         self.generative_model = DecoratorModel.load_from_file(path_to_file, mode)
         self.vocabulary = self.generative_model.vocabulary
@@ -11,23 +11,13 @@ def __init__(self, path_to_file: str, mode: str):
         self.network = self.generative_model.network
 
     def save_to_file(self, path):
-        self.generative_model.save_to_file(path)
-
-    def likelihood(
-        self,
-        scaffold_seqs,
-        scaffold_seq_lengths,
-        decoration_seqs,
-        decoration_seq_lengths,
-    ):
-        return self.generative_model.likelihood(
-            scaffold_seqs, scaffold_seq_lengths, decoration_seqs, decoration_seq_lengths
-        )
+        self.generative_model.save(path)
+
+    def likelihood(self, scaffold_seqs, scaffold_seq_lengths, decoration_seqs, decoration_seq_lengths):
+        return self.generative_model.likelihood(scaffold_seqs, scaffold_seq_lengths, decoration_seqs, decoration_seq_lengths)
 
     def sample(self, scaffold_seqs, scaffold_seq_lengths):
-        return self.generative_model.sample_decorations(
-            scaffold_seqs, scaffold_seq_lengths
-        )
+        return self.generative_model.sample_decorations(scaffold_seqs, scaffold_seq_lengths)
 
     def set_mode(self, mode: str):
         self.generative_model.set_mode(mode)
diff --git a/reinvent_models/model_factory/link_invent_adapter.py b/reinvent_models/model_factory/link_invent_adapter.py
index f6a20c2..8371ee5 100644
--- a/reinvent_models/model_factory/link_invent_adapter.py
+++ b/reinvent_models/model_factory/link_invent_adapter.py
@@ -1,9 +1,9 @@
 from reinvent_models.link_invent.link_invent_model import LinkInventModel
-from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum
 from reinvent_models.model_factory.generative_model_base import GenerativeModelBase
 
 
 class LinkInventAdapter(GenerativeModelBase):
+
     def __init__(self, path_to_file: str, mode: str):
         self.generative_model = LinkInventModel.load_from_file(path_to_file, mode)
         self.vocabulary = self.generative_model.vocabulary
@@ -13,12 +13,8 @@ def __init__(self, path_to_file: str, mode: str):
     def save_to_file(self, path):
         self.generative_model.save_to_file(path)
 
-    def likelihood(
-        self, warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths
-    ):
-        return self.generative_model.likelihood(
-            warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths
-        )
+    def likelihood(self, warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths):
+        return self.generative_model.likelihood(warheads_seqs, warheads_seq_lengths, linker_seqs, linker_seq_lengths)
 
     def sample(self, warheads_seqs, warheads_seq_lengths):
         return self.generative_model.sample(warheads_seqs, warheads_seq_lengths)
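With `@abstractmethod` now on `GenerativeModelBase.save_to_file`, an adapter that forgets to implement it fails at construction time instead of at first call. A self-contained illustration of the mechanism (generic names, not the project's classes):

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def save_to_file(self, path_to_file: str):
        raise NotImplementedError

class Incomplete(Base):
    pass  # save_to_file deliberately missing

try:
    Incomplete()
except TypeError as error:
    print(error)  # can't instantiate abstract class Incomplete ...
```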
diff --git a/reinvent_models/model_factory/reinvent_core_adapter.py b/reinvent_models/model_factory/reinvent_core_adapter.py
index 2ef08ac..7ed96dc 100644
--- a/reinvent_models/model_factory/reinvent_core_adapter.py
+++ b/reinvent_models/model_factory/reinvent_core_adapter.py
@@ -1,15 +1,21 @@
-from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum
+from typing import List
+
+import torch
+
+from reinvent_models.lib_invent.enums.generative_model_regime import GenerativeModelRegimeEnum
 from reinvent_models.model_factory.generative_model_base import GenerativeModelBase
 from reinvent_models.reinvent_core.models.model import Model
 
 
 class ReinventCoreAdapter(GenerativeModelBase):
+
     def __init__(self, path_to_file: str, mode: str):
+        model_regime = GenerativeModelRegimeEnum()
+        mode = mode == model_regime.INFERENCE
         self.generative_model = Model.load_from_file(path_to_file, mode)
-        self.vocabulary = self.generative_model.vocabulary
-        self.tokenizer = self.generative_model.tokenizer
-        self.max_sequence_length = self.generative_model.max_sequence_length
-        self.network = self.generative_model.network
+        self.vocabulary = self.generative_model.vocabulary
+        self.tokenizer = self.generative_model.tokenizer
+        self.max_sequence_length = self.generative_model.max_sequence_length
+        self.network = self.generative_model.network
         # self._nll_loss = self._reinvent_model._nll_loss
 
     def save_to_file(self, path):
@@ -18,11 +24,17 @@
     def likelihood(self, sequences):
         return self.generative_model.likelihood(sequences)
 
-    def sample(self, num, batch_size):
-        return self.generative_model.sample_smiles(num, batch_size)
+    def sample(self, batch_size):
+        return self.generative_model.sample_sequences_and_smiles(batch_size)
+
+    def set_mode(self, mode: str):
+        self.generative_model.set_mode(mode)
 
     def get_network_parameters(self):
         return self.generative_model.get_network_parameters()
 
     def get_vocabulary(self):
         return self.vocabulary
+
+    def likelihood_smiles(self, smiles: List[str]) -> torch.Tensor:
+        return self.generative_model.likelihood_smiles(smiles)
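`ReinventCoreAdapter.__init__` now folds the string regime into the boolean `sampling_mode` flag that `Model.load_from_file(file_path, sampling_mode=False)` expects (see the model.py hunks further down). In isolation, the conversion is just:

```python
from reinvent_models.lib_invent.enums.generative_model_regime import GenerativeModelRegimeEnum

model_regime = GenerativeModelRegimeEnum()

for mode in (model_regime.INFERENCE, model_regime.TRAINING):
    sampling_mode = mode == model_regime.INFERENCE
    print(f"{mode!r} -> sampling_mode={sampling_mode}")
# 'inference' -> sampling_mode=True   (checkpoint loaded for eval/sampling)
# 'training'  -> sampling_mode=False
```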
""" dataset = Dataset(smiles, model.vocabulary, model.tokenizer) - _dataloader = tud.DataLoader( - dataset, batch_size=batch_size, collate_fn=Dataset.collate_fn - ) + _dataloader = tud.DataLoader(dataset, batch_size=batch_size, collate_fn=Dataset.collate_fn) def _iterator(dataloader): for batch in dataloader: diff --git a/reinvent_models/reinvent_core/models/model.py b/reinvent_models/reinvent_core/models/model.py index 562e211..0d38d0f 100644 --- a/reinvent_models/reinvent_core/models/model.py +++ b/reinvent_models/reinvent_core/models/model.py @@ -1,3 +1,4 @@ + """ Implementation of the RNN model """ @@ -7,10 +8,9 @@ import torch.nn as tnn import torch.nn.functional as tnnf +from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum from reinvent_models.reinvent_core.models import vocabulary as mv -# from models import vocabulary as mv - class RNN(tnn.Module): """ @@ -18,16 +18,8 @@ class RNN(tnn.Module): and an output linear layer back to the size of the vocabulary """ - def __init__( - self, - voc_size, - layer_size=512, - num_layers=3, - cell_type="gru", - embedding_layer_size=256, - dropout=0.0, - layer_normalization=False, - ): + def __init__(self, voc_size, layer_size=512, num_layers=3, cell_type='gru', embedding_layer_size=256, dropout=0., + layer_normalization=False): """ Implements a N layer GRU|LSTM cell including an embedding layer and an output linear layer back to the size of the vocabulary @@ -46,26 +38,14 @@ def __init__( self._layer_normalization = layer_normalization self._embedding = tnn.Embedding(voc_size, self._embedding_layer_size) - if self._cell_type == "gru": - self._rnn = tnn.GRU( - self._embedding_layer_size, - self._layer_size, - num_layers=self._num_layers, - dropout=self._dropout, - batch_first=True, - ) - elif self._cell_type == "lstm": - self._rnn = tnn.LSTM( - self._embedding_layer_size, - self._layer_size, - num_layers=self._num_layers, - dropout=self._dropout, - batch_first=True, - ) + if self._cell_type == 'gru': + self._rnn = tnn.GRU(self._embedding_layer_size, self._layer_size, num_layers=self._num_layers, + dropout=self._dropout, batch_first=True) + elif self._cell_type == 'lstm': + self._rnn = tnn.LSTM(self._embedding_layer_size, self._layer_size, num_layers=self._num_layers, + dropout=self._dropout, batch_first=True) else: - raise ValueError( - 'Value of the parameter cell_type should be "gru" or "lstm"' - ) + raise ValueError('Value of the parameter cell_type should be "gru" or "lstm"') self._linear = tnn.Linear(self._layer_size, voc_size) def forward(self, input_vector, hidden_state=None): # pylint: disable=W0221 @@ -97,11 +77,11 @@ def get_params(self): Returns the configuration parameters of the model. """ return { - "dropout": self._dropout, - "layer_size": self._layer_size, - "num_layers": self._num_layers, - "cell_type": self._cell_type, - "embedding_layer_size": self._embedding_layer_size, + 'dropout': self._dropout, + 'layer_size': self._layer_size, + 'num_layers': self._num_layers, + 'cell_type': self._cell_type, + 'embedding_layer_size': self._embedding_layer_size } @@ -110,14 +90,8 @@ class Model: Implements an RNN model using SMILES. """ - def __init__( - self, - vocabulary: mv.Vocabulary, - tokenizer, - network_params=None, - max_sequence_length=256, - no_cuda=False, - ): + def __init__(self, vocabulary: mv.Vocabulary, tokenizer, network_params=None, max_sequence_length=256, + no_cuda=False): """ Implements an RNN. :param vocabulary: Vocabulary to use. 
diff --git a/reinvent_models/reinvent_core/models/model.py b/reinvent_models/reinvent_core/models/model.py
index 562e211..0d38d0f 100644
--- a/reinvent_models/reinvent_core/models/model.py
+++ b/reinvent_models/reinvent_core/models/model.py
@@ -1,3 +1,4 @@
+
 """
 Implementation of the RNN model
 """
@@ -7,10 +8,9 @@
 import torch.nn as tnn
 import torch.nn.functional as tnnf
 
+from reinvent_models.model_factory.enums.model_mode_enum import ModelModeEnum
 from reinvent_models.reinvent_core.models import vocabulary as mv
 
-# from models import vocabulary as mv
-
 
 class RNN(tnn.Module):
     """
@@ -18,16 +18,8 @@ class RNN(tnn.Module):
     and an output linear layer back to the size of the vocabulary
     """
 
-    def __init__(
-        self,
-        voc_size,
-        layer_size=512,
-        num_layers=3,
-        cell_type="gru",
-        embedding_layer_size=256,
-        dropout=0.0,
-        layer_normalization=False,
-    ):
+    def __init__(self, voc_size, layer_size=512, num_layers=3, cell_type='gru', embedding_layer_size=256, dropout=0.,
+                 layer_normalization=False):
         """
         Implements a N layer GRU|LSTM cell including an embedding layer
         and an output linear layer back to the size of the vocabulary
@@ -46,26 +38,14 @@ def __init__(
         self._layer_normalization = layer_normalization
 
         self._embedding = tnn.Embedding(voc_size, self._embedding_layer_size)
-        if self._cell_type == "gru":
-            self._rnn = tnn.GRU(
-                self._embedding_layer_size,
-                self._layer_size,
-                num_layers=self._num_layers,
-                dropout=self._dropout,
-                batch_first=True,
-            )
-        elif self._cell_type == "lstm":
-            self._rnn = tnn.LSTM(
-                self._embedding_layer_size,
-                self._layer_size,
-                num_layers=self._num_layers,
-                dropout=self._dropout,
-                batch_first=True,
-            )
+        if self._cell_type == 'gru':
+            self._rnn = tnn.GRU(self._embedding_layer_size, self._layer_size, num_layers=self._num_layers,
+                                dropout=self._dropout, batch_first=True)
+        elif self._cell_type == 'lstm':
+            self._rnn = tnn.LSTM(self._embedding_layer_size, self._layer_size, num_layers=self._num_layers,
+                                 dropout=self._dropout, batch_first=True)
         else:
-            raise ValueError(
-                'Value of the parameter cell_type should be "gru" or "lstm"'
-            )
+            raise ValueError('Value of the parameter cell_type should be "gru" or "lstm"')
         self._linear = tnn.Linear(self._layer_size, voc_size)
 
     def forward(self, input_vector, hidden_state=None):  # pylint: disable=W0221
@@ -97,11 +77,11 @@ def get_params(self):
         Returns the configuration parameters of the model.
         """
         return {
-            "dropout": self._dropout,
-            "layer_size": self._layer_size,
-            "num_layers": self._num_layers,
-            "cell_type": self._cell_type,
-            "embedding_layer_size": self._embedding_layer_size,
+            'dropout': self._dropout,
+            'layer_size': self._layer_size,
+            'num_layers': self._num_layers,
+            'cell_type': self._cell_type,
+            'embedding_layer_size': self._embedding_layer_size
         }
 
 
@@ -110,14 +90,8 @@ class Model:
     Implements an RNN model using SMILES.
     """
 
-    def __init__(
-        self,
-        vocabulary: mv.Vocabulary,
-        tokenizer,
-        network_params=None,
-        max_sequence_length=256,
-        no_cuda=False,
-    ):
+    def __init__(self, vocabulary: mv.Vocabulary, tokenizer, network_params=None, max_sequence_length=256,
+                 no_cuda=False):
         """
         Implements an RNN.
         :param vocabulary: Vocabulary to use.
@@ -128,6 +102,7 @@ def __init__(
         self.vocabulary = vocabulary
         self.tokenizer = tokenizer
         self.max_sequence_length = max_sequence_length
+        self._model_modes = ModelModeEnum()
 
         if not isinstance(network_params, dict):
             network_params = {}
@@ -138,6 +113,14 @@ def __init__(
 
         self._nll_loss = tnn.NLLLoss(reduction="none")
 
+    def set_mode(self, mode: str):
+        if mode == self._model_modes.TRAINING:
+            self.network.train()
+        elif mode == self._model_modes.INFERENCE:
+            self.network.eval()
+        else:
+            raise ValueError(f"Invalid model mode '{mode}'")
+
     @classmethod
     def load_from_file(cls, file_path: str, sampling_mode=False):
         """
@@ -152,10 +135,10 @@ def load_from_file(cls, file_path: str, sampling_mode=False):
         network_params = save_dict.get("network_params", {})
         model = Model(
-            vocabulary=save_dict["vocabulary"],
-            tokenizer=save_dict.get("tokenizer", mv.SMILESTokenizer()),
+            vocabulary=save_dict['vocabulary'],
+            tokenizer=save_dict.get('tokenizer', mv.SMILESTokenizer()),
             network_params=network_params,
-            max_sequence_length=save_dict["max_sequence_length"],
+            max_sequence_length=save_dict['max_sequence_length']
         )
         model.network.load_state_dict(save_dict["network"])
         if sampling_mode:
@@ -168,11 +151,11 @@ def save(self, file: str):
         :param file: it's actually a path
         """
         save_dict = {
-            "vocabulary": self.vocabulary,
-            "tokenizer": self.tokenizer,
-            "max_sequence_length": self.max_sequence_length,
-            "network": self.network.state_dict(),
-            "network_params": self.network.get_params(),
+            'vocabulary': self.vocabulary,
+            'tokenizer': self.tokenizer,
+            'max_sequence_length': self.max_sequence_length,
+            'network': self.network.state_dict(),
+            'network_params': self.network.get_params()
         }
         torch.save(save_dict, file)
 
@@ -184,11 +167,9 @@ def likelihood_smiles(self, smiles) -> torch.Tensor:
         def collate_fn(encoded_seqs):
             """Function to take a list of encoded sequences and turn them into a batch"""
             max_length = max([seq.size(0) for seq in encoded_seqs])
-            collated_arr = torch.zeros(
-                len(encoded_seqs), max_length, dtype=torch.long
-            )  # padded with zeroes
+            collated_arr = torch.zeros(len(encoded_seqs), max_length, dtype=torch.long)  # padded with zeroes
             for i, seq in enumerate(encoded_seqs):
-                collated_arr[i, : seq.size(0)] = seq
+                collated_arr[i, :seq.size(0)] = seq
             return collated_arr
 
         padded_sequences = collate_fn(sequences)
@@ -214,9 +195,7 @@ def sample_smiles(self, num=128, batch_size=128) -> Tuple[List, np.array]:
         :smiles: (n) A list with SMILES.
         :likelihoods: (n) A list of likelihoods.
         """
-        batch_sizes = [batch_size for _ in range(num // batch_size)] + [
-            num % batch_size
-        ]
+        batch_sizes = [batch_size for _ in range(num // batch_size)] + [num % batch_size]
         smiles_sampled = []
         likelihoods_sampled = []
 
@@ -224,10 +203,7 @@ def sample_smiles(self, num=128, batch_size=128) -> Tuple[List, np.array]:
             if not size:
                 break
             seqs, likelihoods = self._sample(batch_size=size)
-            smiles = [
-                self.tokenizer.untokenize(self.vocabulary.decode(seq))
-                for seq in seqs.cpu().numpy()
-            ]
+            smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq)) for seq in seqs.cpu().numpy()]
 
             smiles_sampled.extend(smiles)
             likelihoods_sampled.append(likelihoods.data.cpu().numpy())
@@ -235,14 +211,9 @@ def sample_smiles(self, num=128, batch_size=128) -> Tuple[List, np.array]:
             del seqs, likelihoods
         return smiles_sampled, np.concatenate(likelihoods_sampled)
 
-    def sample_sequences_and_smiles(
-        self, batch_size=128
-    ) -> Tuple[torch.Tensor, List, torch.Tensor]:
+    def sample_sequences_and_smiles(self, batch_size=128) -> Tuple[torch.Tensor, List, torch.Tensor]:
         seqs, likelihoods = self._sample(batch_size=batch_size)
-        smiles = [
-            self.tokenizer.untokenize(self.vocabulary.decode(seq))
-            for seq in seqs.cpu().numpy()
-        ]
+        smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq)) for seq in seqs.cpu().numpy()]
         return seqs, smiles, likelihoods
 
     # @torch.no_grad()
@@ -250,9 +221,7 @@ def _sample(self, batch_size=128) -> Tuple[torch.Tensor, torch.Tensor]:
         start_token = torch.zeros(batch_size, dtype=torch.long)
         start_token[:] = self.vocabulary["^"]
         input_vector = start_token
-        sequences = [
-            self.vocabulary["^"] * torch.ones([batch_size, 1], dtype=torch.long)
-        ]
+        sequences = [self.vocabulary["^"] * torch.ones([batch_size, 1], dtype=torch.long)]
         # NOTE: The first token never gets added in the loop so the sequences are initialized with a start token
         hidden_state = None
         nlls = torch.zeros(batch_size)
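Two behavioral additions in model.py: `set_mode` toggles `network.train()`/`network.eval()` via the `ModelModeEnum` strings, and `sample_smiles` keeps its batch-splitting arithmetic, now on one line. The split is easy to sanity-check standalone:

```python
# How sample_smiles() splits `num` requested SMILES into batches:
num, batch_size = 300, 128
batch_sizes = [batch_size for _ in range(num // batch_size)] + [num % batch_size]
print(batch_sizes)  # [128, 128, 44]

# When num is an exact multiple of batch_size, the trailing remainder is 0
# and the `if not size: break` guard in the sampling loop skips it:
num = 256
batch_sizes = [batch_size for _ in range(num // batch_size)] + [num % batch_size]
print(batch_sizes)  # [128, 128, 0] -> the final 0 triggers the break
```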
""" - batch_sizes = [batch_size for _ in range(num // batch_size)] + [ - num % batch_size - ] + batch_sizes = [batch_size for _ in range(num // batch_size)] + [num % batch_size] smiles_sampled = [] likelihoods_sampled = [] @@ -224,10 +203,7 @@ def sample_smiles(self, num=128, batch_size=128) -> Tuple[List, np.array]: if not size: break seqs, likelihoods = self._sample(batch_size=size) - smiles = [ - self.tokenizer.untokenize(self.vocabulary.decode(seq)) - for seq in seqs.cpu().numpy() - ] + smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq)) for seq in seqs.cpu().numpy()] smiles_sampled.extend(smiles) likelihoods_sampled.append(likelihoods.data.cpu().numpy()) @@ -235,14 +211,9 @@ def sample_smiles(self, num=128, batch_size=128) -> Tuple[List, np.array]: del seqs, likelihoods return smiles_sampled, np.concatenate(likelihoods_sampled) - def sample_sequences_and_smiles( - self, batch_size=128 - ) -> Tuple[torch.Tensor, List, torch.Tensor]: + def sample_sequences_and_smiles(self, batch_size=128) -> Tuple[torch.Tensor, List, torch.Tensor]: seqs, likelihoods = self._sample(batch_size=batch_size) - smiles = [ - self.tokenizer.untokenize(self.vocabulary.decode(seq)) - for seq in seqs.cpu().numpy() - ] + smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq)) for seq in seqs.cpu().numpy()] return seqs, smiles, likelihoods # @torch.no_grad() @@ -250,9 +221,7 @@ def _sample(self, batch_size=128) -> Tuple[torch.Tensor, torch.Tensor]: start_token = torch.zeros(batch_size, dtype=torch.long) start_token[:] = self.vocabulary["^"] input_vector = start_token - sequences = [ - self.vocabulary["^"] * torch.ones([batch_size, 1], dtype=torch.long) - ] + sequences = [self.vocabulary["^"] * torch.ones([batch_size, 1], dtype=torch.long)] # NOTE: The first token never gets added in the loop so the sequences are initialized with a start token hidden_state = None nlls = torch.zeros(batch_size) diff --git a/reinvent_models/reinvent_core/models/vocabulary.py b/reinvent_models/reinvent_core/models/vocabulary.py index 3eb6161..7aa8d70 100644 --- a/reinvent_models/reinvent_core/models/vocabulary.py +++ b/reinvent_models/reinvent_core/models/vocabulary.py @@ -84,13 +84,12 @@ class SMILESTokenizer: REGEXPS = { "brackets": re.compile(r"(\[[^\]]*\])"), "2_ring_nums": re.compile(r"(%\d{2})"), - "brcl": re.compile(r"(Br|Cl)"), + "brcl": re.compile(r"(Br|Cl)") } REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"] def tokenize(self, data, with_begin_and_end=True): """Tokenizes a SMILES string.""" - def split_by(data, regexps): if not regexps: return list(data) @@ -127,7 +126,5 @@ def create_vocabulary(smiles_list, tokenizer): tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False)) vocabulary = Vocabulary() - vocabulary.update( - ["$", "^"] + sorted(tokens) - ) # end token is 0 (also counts as padding) + vocabulary.update(["$", "^"] + sorted(tokens)) # end token is 0 (also counts as padding) return vocabulary diff --git a/setup.py b/setup.py index f773e7e..0568de7 100644 --- a/setup.py +++ b/setup.py @@ -5,12 +5,12 @@ setuptools.setup( name="reinvent_models", - version="0.0.7", - author="PaccMann Team", - description="Generative models for Reinvent adapted for PaccMann", + version="0.0.14", + author="GT4SD Team", + description="Generative models for Reinvent adapted for GT4SD", long_description=long_description, long_description_content_type="text/markdown", - url="https://github.com/PaccMann/reinvent_models", + url="https://github.com/GT4SD/reinvent_models", 
diff --git a/setup.py b/setup.py
index f773e7e..0568de7 100644
--- a/setup.py
+++ b/setup.py
@@ -5,12 +5,12 @@
 setuptools.setup(
     name="reinvent_models",
-    version="0.0.7",
-    author="PaccMann Team",
-    description="Generative models for Reinvent adapted for PaccMann",
+    version="0.0.14",
+    author="GT4SD Team",
+    description="Generative models for Reinvent adapted for GT4SD",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    url="https://github.com/PaccMann/reinvent_models",
+    url="https://github.com/GT4SD/reinvent_models",
     packages=setuptools.find_packages(exclude=("testing",)),
     classifiers=[
         "Programming Language :: Python :: 3",