Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notebooks and various fixes for masked model training #5

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
lightning_logs/
notebooks/_*.ipynb

wandb/
# vscode
.vscode

# jupyter
MANIFEST
build
Expand Down Expand Up @@ -157,3 +156,12 @@ venv.bak/

# mypy
.mypy_cache/
molexpress/**/*.ckpt
molexpress/**/*.pth
molexpress/**/*.txt
molexpress/**/*.csv
molexpress/**/*.zip




96 changes: 71 additions & 25 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,38 @@
def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
    """Encode the residues of one peptide into a single disjoint graph.

    Every residue is encoded with `_encode_residue`; the per-residue graphs
    are then merged with `_merge_molecular_graphs` and two bookkeeping
    entries are added to the merged graph:

    * ``"residue_size"``: the sizes returned by `_encode_residue`, one per
      residue, in input order.
    * ``"peptide_size"``: a one-element int32 array holding ``len(residues)``.

    Args:
        residues: The residues of the peptide, in order.

    Returns:
        The merged graph with the ``"residue_size"`` and ``"peptide_size"``
        entries set.

    Raises:
        Exception: Whatever `_encode_residue` raises for a residue that
            cannot be encoded. The failure is reported and re-raised with
            its original type.
    """
    residue_graphs = []
    residue_sizes = []
    for residue in residues:
        try:
            residue_graph, residue_size = self._encode_residue(
                residue, self.node_encoder, self.edge_encoder
            )
        except Exception:
            # Do not fall through to the appends below: `residue_graph` would
            # either be unbound (first residue) or still hold the previous
            # residue's graph, silently corrupting the peptide. Skipping is
            # not an option either, because "peptide_size" counts every
            # input residue.
            print("Residues cannot be encoded properly")
            raise

        residue_graphs.append(residue_graph)
        residue_sizes.append(residue_size)

    disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)

    try:
        disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
    except Exception:
        print("Cannot construct disjoint graph")
        raise

    disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
    return disjoint_peptide_graph



@staticmethod
@lru_cache(maxsize=None)
def _encode_residue(
Expand Down Expand Up @@ -96,10 +116,19 @@
"""
disjoint_peptide_graphs = data


disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
)
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate(
[g["residue_size"].shape[:1] for g in disjoint_peptide_graphs]
).astype("int32")
disjoint_peptide_batch_graph["residue_size"] = np.concatenate(
[g["residue_size"] for g in disjoint_peptide_graphs]
).astype("int32")


# print(disjoint_peptide_batch_graph)
node_state = disjoint_peptide_batch_graph['node_state']
node_mask = np.random.uniform(size=node_state.shape[0]) < node_masking_rate
disjoint_peptide_batch_graph['node_loss_weight'] = np.copy(node_mask.astype(node_state.dtype))
Expand All @@ -117,37 +146,51 @@
mask_state[:, -1] = 1.
disjoint_peptide_batch_graph['edge_state'] = np.where(
edge_mask[:, None], mask_state, edge_state)

# print(disjoint_peptide_batch_graph)
# residue_size = np.array([g["residue_size"] for g in molecular_graphs])
return disjoint_peptide_batch_graph

@staticmethod
def _merge_molecular_graphs(
    molecular_graphs: list[types.MolecularGraph],
) -> types.MolecularGraph:
    """Merge several molecular graphs into one disjoint graph.

    The ``"node_state"`` arrays (and the ``"edge_state"`` arrays, when the
    first graph has that key) are concatenated in input order. The
    ``"edge_src"`` / ``"edge_dst"`` indices of each graph are shifted by the
    total number of nodes of all graphs that precede it, so that they keep
    pointing at the same nodes inside the concatenated ``"node_state"``.

    Args:
        molecular_graphs: Non-empty list of graphs to merge. Each graph must
            provide ``"node_state"``, ``"edge_src"`` and ``"edge_dst"``.

    Returns:
        A dict with ``"node_state"``, ``"edge_src"``, ``"edge_dst"`` and,
        when present on the first graph, ``"edge_state"``.

    Raises:
        ValueError: If `molecular_graphs` is empty, or if the ``"edge_state"``
            arrays cannot be concatenated (mismatching shapes).
    """
    if not molecular_graphs:
        # Fail loudly here instead of returning None and letting the caller
        # crash later on item assignment.
        raise ValueError("Cannot merge an empty list of molecular graphs")

    num_nodes = np.array([g["node_state"].shape[0] for g in molecular_graphs])

    disjoint_molecular_graph = {}

    disjoint_molecular_graph["node_state"] = np.concatenate(
        [g["node_state"] for g in molecular_graphs]
    )

    if "edge_state" in molecular_graphs[0]:
        try:
            disjoint_molecular_graph["edge_state"] = np.concatenate(
                [g["edge_state"] for g in molecular_graphs]
            )
        except ValueError as e:
            shapes = [g["edge_state"].shape for g in molecular_graphs]
            raise ValueError(
                "Error during concatenation of edge_state arrays with shapes "
                f"{shapes}. This is usually due to the presence of structures "
                "without any bonds (ions / atoms)"
            ) from e

    edge_src = np.concatenate([g["edge_src"] for g in molecular_graphs])
    edge_dst = np.concatenate([g["edge_dst"] for g in molecular_graphs])
    num_edges = np.array([g["edge_src"].shape[0] for g in molecular_graphs])

    # Offset of graph i = number of nodes in graphs 0..i-1. The offsets must
    # be cumulative; using the previous graph's node count alone mis-indexes
    # every graph from the third one onwards.
    node_offsets = np.concatenate([[0], np.cumsum(num_nodes[:-1])])
    # One offset per edge: repeat each graph's offset once for each of its edges.
    edge_incr = np.repeat(node_offsets, num_edges)

    disjoint_molecular_graph["edge_src"] = edge_src + edge_incr
    disjoint_molecular_graph["edge_dst"] = edge_dst + edge_incr

    return disjoint_molecular_graph


class Composer:
Expand Down Expand Up @@ -234,6 +277,7 @@
}



class MolecularNodeEncoder:
def __init__(
self,
Expand All @@ -244,9 +288,11 @@
self.supports_masking = supports_masking

def __call__(self, molecule: types.Molecule) -> np.ndarray:
    """Encode every atom of `molecule` with the configured featurizer.

    Args:
        molecule: The molecule whose atoms (``molecule.GetAtoms()``) are
            encoded.

    Returns:
        A dict with the single key ``"node_state"`` holding the per-atom
        encodings stacked along axis 0. When `supports_masking` is set, the
        encodings are padded with one trailing zero on the second axis
        (assumes the featurizer returns 1-D vectors -- TODO confirm).

    Raises:
        ValueError: If `molecule` is None, or if it has no atoms (`np.stack`
            rejects an empty sequence).
    """
    # A filter inside the comprehension cannot guard against a missing
    # molecule, because `molecule.GetAtoms()` is evaluated first. Check up front.
    if molecule is None:
        raise ValueError("Cannot encode node features: molecule is None")

    node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0)
    if self.supports_masking:
        # Reserve one extra trailing slot per atom (zero-filled by np.pad).
        node_encodings = np.pad(node_encodings, [(0, 0), (0, 1)])
    return {
        "node_state": node_encodings,
    }

2 changes: 1 addition & 1 deletion molexpress/layers/gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
if self.skip_connection:
if self._transform_skip_connection:
node_state = gnn_ops.transform(state=node_state, kernel=self.skip_connect_kernel)
node_state_updated += node_state
node_state_updated = node_state_updated + node_state

if self.dropout_rate:
node_state_updated = self.dropout(node_state_updated)
Expand Down
2 changes: 1 addition & 1 deletion molexpress/layers/gin_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
edge_weight=edge_weight,
)

node_state_updated += (1 + self.epsilon) * node_state
node_state_updated = node_state_updated + (1 + self.epsilon) * node_state

node_state_updated = gnn_ops.transform(
state=node_state_updated, kernel=self.node_kernel_1, bias=self.node_bias_1
Expand Down
5 changes: 3 additions & 2 deletions molexpress/ops/gnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def transform(
# kernel.rank == 3 and state.rank == 2
state_transformed = keras.ops.einsum('ij,jkh->ikh', state, kernel)
if bias is not None:
state_transformed += bias
state_transformed =state_transformed + bias
return state_transformed

def aggregate(
Expand Down Expand Up @@ -71,13 +71,14 @@ def aggregate(
edge_src = keras.ops.expand_dims(edge_src, axis=-1)
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)

# print(edge_src.size(),node_state.size())
node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0)

if edge_weight is not None:
node_state_src *= edge_weight

if edge_state is not None:
node_state_src += edge_state
node_state_src = node_state_src + edge_state

edge_dst = keras.ops.squeeze(edge_dst)

Expand Down
62 changes: 62 additions & 0 deletions molexpress/pretraining/canonicalise_smiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import multiprocessing as mp
from rdkit import Chem
from tqdm import tqdm
import os

Check failure on line 4 in molexpress/pretraining/canonicalise_smiles.py

View workflow job for this annotation

GitHub Actions / test (3.9)

Ruff (F401)

molexpress/pretraining/canonicalise_smiles.py:4:8: F401 `os` imported but unused

Check failure on line 4 in molexpress/pretraining/canonicalise_smiles.py

View workflow job for this annotation

GitHub Actions / test (3.10)

Ruff (F401)

molexpress/pretraining/canonicalise_smiles.py:4:8: F401 `os` imported but unused

Check failure on line 4 in molexpress/pretraining/canonicalise_smiles.py

View workflow job for this annotation

GitHub Actions / test (3.11)

Ruff (F401)

molexpress/pretraining/canonicalise_smiles.py:4:8: F401 `os` imported but unused

# Function to canonicalize SMILES
def canonicalize_smiles(smiles):
    """Return the canonical form of `smiles`.

    The string is parsed with `Chem.MolFromSmiles`; when that yields a falsy
    result, None is returned instead of a SMILES string.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if not parsed:
        return None
    return Chem.MolToSmiles(parsed, canonical=True)

# Process a chunk of SMILES
def process_chunk(smiles_chunk):
    """Canonicalise a batch of SMILES strings.

    Each entry is stripped of surrounding whitespace and passed to
    `canonicalize_smiles`. Entries with a truthy result contribute that
    result to the first list; all others contribute their stripped input to
    the second list.

    Returns:
        A ``(valid_smiles, invalid_smiles)`` tuple of lists.
    """
    accepted = []
    rejected = []
    for raw in smiles_chunk:
        stripped = raw.strip()
        canonical = canonicalize_smiles(stripped)
        if not canonical:
            rejected.append(stripped)
            continue
        accepted.append(canonical)
    return accepted, rejected

# Read SMILES from input file and split them into chunks
def process_smiles_file(input_file, output_file, invalid_file, num_processes=4, chunk_size=100000):
    """Canonicalise every SMILES line of `input_file` in parallel.

    The input lines are split into chunks of `chunk_size` lines, each chunk
    is handed to `process_chunk` in a pool of `num_processes` workers, and
    the collected results are written one per line.

    Args:
        input_file: Path of the text file to read, one SMILES per line.
        output_file: Path that receives the canonical SMILES.
        invalid_file: Path that receives the entries that could not be
            canonicalised.
        num_processes: Number of worker processes in the pool.
        chunk_size: Number of input lines per work item.
    """
    # Read the input once; the context manager closes the handle. (A separate
    # line-counting pass is not needed: the progress bar is sized by the
    # number of chunks.)
    with open(input_file, 'r') as infile:
        smiles_list = infile.readlines()

    # Split the smiles into chunks
    smiles_chunks = [smiles_list[i:i + chunk_size] for i in range(0, len(smiles_list), chunk_size)]

    # Process each chunk in parallel; imap preserves chunk order.
    with mp.Pool(processes=num_processes) as pool:
        results = list(tqdm(pool.imap(process_chunk, smiles_chunks), total=len(smiles_chunks)))

    # Gather results
    valid_smiles = []
    invalid_smiles = []
    for valid, invalid in results:
        valid_smiles.extend(valid)
        invalid_smiles.extend(invalid)

    # Write the results to the output files
    with open(output_file, 'w') as outfile:
        outfile.write('\n'.join(valid_smiles) + '\n')
    with open(invalid_file, 'w') as invalid_outfile:
        invalid_outfile.write('\n'.join(invalid_smiles) + '\n')

# Example usage
input_file = 'filtered_pubchem.txt'  # Your input file containing SMILES strings
output_file = 'canon_filtered_pubchem.txt'  # Output file for valid canonical SMILES
invalid_file = 'invalid_smiles.txt'  # Output file for invalid SMILES

# Only start the pool when run as a script. Under the "spawn" start method
# every worker re-imports this module, and an unguarded call here would try
# to launch the whole job again from inside each worker.
if __name__ == "__main__":
    # Adjust num_processes based on your machine's CPU cores, and tune chunk_size based on file size
    process_smiles_file(input_file, output_file, invalid_file, num_processes=8, chunk_size=100000)
Loading
Loading