Commit fcf6460

first fixes
hramdasan committed Nov 14, 2024
1 parent 8927de2 commit fcf6460
Showing 14 changed files with 13,765 additions and 31 deletions.
12 changes: 10 additions & 2 deletions .gitignore
@@ -1,9 +1,8 @@
lightning_logs/
notebooks/_*.ipynb

wandb/
# vscode
.vscode

# jupyter
MANIFEST
build
@@ -157,3 +156,12 @@ venv.bak/

# mypy
.mypy_cache/
molexpress/**/*.ckpt
molexpress/**/*.pth
molexpress/**/*.txt
molexpress/**/*.csv
molexpress/**/*.zip




96 changes: 71 additions & 25 deletions molexpress/datasets/encoders.py
@@ -26,18 +26,38 @@ def __init__(
def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI]) -> np.ndarray:
residue_graphs = []
residue_sizes = []
# print(residues)
for residue in residues:
residue_graph, residue_size = self._encode_residue(
residue, self.node_encoder, self.edge_encoder
)
try:
residue_graph, residue_size = self._encode_residue(
residue, self.node_encoder, self.edge_encoder
)
# print("Printing residue graphs",residue_graph)
except Exception as e:

Check failure on line 36 in molexpress/datasets/encoders.py (GitHub Actions / test, Python 3.9, 3.10, 3.11): Ruff (F841) Local variable `e` is assigned to but never used
# print(residues)
print("Residues cannot be encoded properly")
# continue

residue_graphs.append(residue_graph)
residue_sizes.append(residue_size)


#print(residues, residue_graphs)
# print(residue_graphs)
disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)

# print(disjoint_peptide_graph)
try:
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
except Exception as e:
# print(disjoint_peptide_graph)
print("Cannot construct disjoint graph")
raise e

disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph
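
Aside: the Ruff F841 failure above flags that `e` is bound but never used in the new except clause. A minimal sketch of one way to resolve it (hypothetical error message, not part of this commit):

    try:
        residue_graph, residue_size = self._encode_residue(
            residue, self.node_encoder, self.edge_encoder
        )
    except Exception as exc:
        # Chain the original exception so the failing residue is visible
        # and the bound name is actually used (which silences F841).
        raise ValueError(f"Residue could not be encoded: {residue!r}") from exc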



@staticmethod
@lru_cache(maxsize=None)
def _encode_residue(
@@ -96,10 +116,19 @@ def masked_collate_fn(
"""
disjoint_peptide_graphs = data


disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
)
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate(
[g["residue_size"].shape[:1] for g in disjoint_peptide_graphs]
).astype("int32")
disjoint_peptide_batch_graph["residue_size"] = np.concatenate(
[g["residue_size"] for g in disjoint_peptide_graphs]
).astype("int32")


# print(disjoint_peptide_batch_graph)
node_state = disjoint_peptide_batch_graph['node_state']
node_mask = np.random.uniform(size=node_state.shape[0]) < node_masking_rate
disjoint_peptide_batch_graph['node_loss_weight'] = np.copy(node_mask.astype(node_state.dtype))
@@ -117,37 +146,51 @@ def masked_collate_fn(
mask_state[:, -1] = 1.
disjoint_peptide_batch_graph['edge_state'] = np.where(
edge_mask[:, None], mask_state, edge_state)

# print(disjoint_peptide_batch_graph)
# residue_size = np.array([g["residue_size"] for g in molecular_graphs])
return disjoint_peptide_batch_graph
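
Note: the masking logic in masked_collate_fn above is the familiar masked-pretraining corruption: draw a Bernoulli mask over nodes (and likewise edges), keep the mask as a per-element loss weight, and overwrite masked rows with a mask vector whose final padded column is set to 1. A compact sketch of the node half (toy shapes and a hypothetical 15% masking rate):

    import numpy as np

    node_state = np.random.rand(5, 4).astype("float32")  # last column is the mask-flag slot
    node_mask = np.random.uniform(size=node_state.shape[0]) < 0.15

    mask_state = np.zeros_like(node_state)
    mask_state[:, -1] = 1.0                    # marks "this node is masked"

    corrupted = np.where(node_mask[:, None], mask_state, node_state)
    loss_weight = node_mask.astype(node_state.dtype)  # supervise only masked nodes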

@staticmethod
def _merge_molecular_graphs(
molecular_graphs: list[types.MolecularGraph],
) -> types.MolecularGraph:
# print(molecular_graphs)
# print([g for g in molecular_graphs])


num_nodes = np.array([g["node_state"].shape[0] for g in molecular_graphs])

disjoint_molecular_graph = {}

disjoint_molecular_graph["node_state"] = np.concatenate(
[g["node_state"] for g in molecular_graphs]
)

if "edge_state" in molecular_graphs[0]:
disjoint_molecular_graph["edge_state"] = np.concatenate(
[g["edge_state"] for g in molecular_graphs]
if len(molecular_graphs) > 0:
disjoint_molecular_graph["node_state"] = np.concatenate(
[g["node_state"] for g in molecular_graphs]
)

edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
num_edges = np.array([graph["edge_src"].shape[0] for graph in molecular_graphs])
indices = np.repeat(range(len(molecular_graphs)), num_edges)
edge_incr = np.concatenate([[0], num_nodes[:-1]])
edge_incr = np.take_along_axis(edge_incr, indices, axis=0)

disjoint_molecular_graph["edge_src"] = edge_src + edge_incr
disjoint_molecular_graph["edge_dst"] = edge_dst + edge_incr

return disjoint_molecular_graph
if "edge_state" in molecular_graphs[0]:
try:
disjoint_molecular_graph["edge_state"] = np.concatenate(
[g["edge_state"] for g in molecular_graphs]
)
except ValueError as e:

print("Error is due to the presence of structures without any bonds, usually these are ions / atoms")
print("Error during concatenation. Shapes of edge_state arrays:")
print([g["edge_state"].shape for g in molecular_graphs])

raise e

edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
num_edges = np.array([graph["edge_src"].shape[0] for graph in molecular_graphs])
indices = np.repeat(range(len(molecular_graphs)), num_edges)
edge_incr = np.concatenate([[0], num_nodes[:-1]])
edge_incr = np.take_along_axis(edge_incr, indices, axis=0)

disjoint_molecular_graph["edge_src"] = edge_src + edge_incr
disjoint_molecular_graph["edge_dst"] = edge_dst + edge_incr

return disjoint_molecular_graph
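
For reference, the edge-index arithmetic in `_merge_molecular_graphs` is the standard disjoint-union batching trick: every edge of graph k is shifted by the number of nodes in the graphs that precede it, so all indices stay valid in the merged graph. A self-contained NumPy sketch with two toy graphs (illustrative data only; the offsets here are computed with a cumulative sum):

    import numpy as np

    # Two toy graphs: a 3-node triangle and a 2-node pair.
    graphs = [
        {"node_state": np.ones((3, 4), dtype="float32"),
         "edge_src": np.array([0, 1, 2]), "edge_dst": np.array([1, 2, 0])},
        {"node_state": np.ones((2, 4), dtype="float32"),
         "edge_src": np.array([0]), "edge_dst": np.array([1])},
    ]

    num_nodes = np.array([g["node_state"].shape[0] for g in graphs])  # [3, 2]
    num_edges = np.array([g["edge_src"].shape[0] for g in graphs])    # [3, 1]

    # Offset for graph k = total node count of graphs 0..k-1.
    offsets = np.concatenate([[0], np.cumsum(num_nodes)[:-1]])        # [0, 3]
    edge_incr = np.repeat(offsets, num_edges)                         # [0, 0, 0, 3]

    edge_src = np.concatenate([g["edge_src"] for g in graphs]) + edge_incr
    edge_dst = np.concatenate([g["edge_dst"] for g in graphs]) + edge_incr
    # edge_src -> [0, 1, 2, 3]: the second graph's only edge now runs 3 -> 4.

The new except branch around the edge_state concatenation targets the degenerate case this trick does not cover cleanly: structures without any bonds (lone ions or atoms), whose empty edge_state arrays can disagree in shape with the rest of the batch, which is what the shape printout is meant to surface.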


class Composer:
@@ -234,6 +277,7 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:
}



class MolecularNodeEncoder:
def __init__(
self,
@@ -244,9 +288,11 @@ def __init__(
self.supports_masking = supports_masking

def __call__(self, molecule: types.Molecule) -> np.ndarray:
node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms()], axis=0)

node_encodings = np.stack([self.featurizer(atom) for atom in molecule.GetAtoms() if molecule ], axis=0)
if self.supports_masking:
node_encodings = np.pad(node_encodings, [(0, 0), (0, 1)])
return {
"node_state": np.stack(node_encodings),
}

2 changes: 1 addition & 1 deletion molexpress/layers/gcn_conv.py
@@ -101,7 +101,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
if self.skip_connection:
if self._transform_skip_connection:
node_state = gnn_ops.transform(state=node_state, kernel=self.skip_connect_kernel)
node_state_updated += node_state
node_state_updated = node_state_updated + node_state

if self.dropout_rate:
node_state_updated = self.dropout(node_state_updated)
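
Side note on the one-line change above (and its twins in gin_conv.py and gnn_ops.py below): `x += y` is an in-place augmented assignment, which can mutate a shared buffer or fail outright on backend tensors that do not allow in-place updates, whereas `x = x + y` always builds a fresh tensor. A tiny NumPy-only illustration of the aliasing difference (toy example, unrelated to the layer code):

    import numpy as np

    a = np.ones(3)
    b = a        # b aliases the same buffer as a
    a += 1       # in-place: the shared buffer is mutated
    print(b)     # [2. 2. 2.] (b changed too)

    a = np.ones(3)
    b = a
    a = a + 1    # out-of-place: a is rebound to a new array
    print(b)     # [1. 1. 1.] (b is untouched)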
2 changes: 1 addition & 1 deletion molexpress/layers/gin_conv.py
@@ -107,7 +107,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
edge_weight=edge_weight,
)

node_state_updated += (1 + self.epsilon) * node_state
node_state_updated = node_state_updated + (1 + self.epsilon) * node_state

node_state_updated = gnn_ops.transform(
state=node_state_updated, kernel=self.node_kernel_1, bias=self.node_bias_1
5 changes: 3 additions & 2 deletions molexpress/ops/gnn_ops.py
@@ -33,7 +33,7 @@ def transform(
# kernel.rank == 3 and state.rank == 2
state_transformed = keras.ops.einsum('ij,jkh->ikh', state, kernel)
if bias is not None:
state_transformed += bias
state_transformed = state_transformed + bias
return state_transformed

def aggregate(
@@ -71,13 +71,14 @@ def aggregate(
edge_src = keras.ops.expand_dims(edge_src, axis=-1)
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)

# print(edge_src.size(),node_state.size())
node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0)

if edge_weight is not None:
node_state_src *= edge_weight

if edge_state is not None:
node_state_src += edge_state
node_state_src = node_state_src + edge_state

edge_dst = keras.ops.squeeze(edge_dst)

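
For orientation, `aggregate` follows the usual gather/scatter message-passing pattern: gather one message per edge from its source node (the take_along_axis call above), optionally scale it by the edge weight and add the edge state, then sum the messages at their destination nodes in the remainder of the function, collapsed here. A minimal NumPy sketch of that pattern (toy data; the real layer runs on keras.ops):

    import numpy as np

    node_state = np.array([[1.], [2.], [4.]])  # 3 nodes, 1 feature
    edge_src = np.array([0, 1, 2])             # edges 0->1, 1->2, 2->0
    edge_dst = np.array([1, 2, 0])

    # Gather: one message per edge, read from the source node.
    messages = node_state[edge_src]            # shape (num_edges, 1)

    # Scatter-add: accumulate incoming messages at each destination.
    aggregated = np.zeros_like(node_state)
    np.add.at(aggregated, edge_dst, messages)
    print(aggregated.ravel())                  # [4. 1. 2.]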
62 changes: 62 additions & 0 deletions molexpress/pretraining/canonicalise_smiles.py
@@ -0,0 +1,62 @@
import multiprocessing as mp
from rdkit import Chem
from tqdm import tqdm
import os

Check failure on line 4 in molexpress/pretraining/canonicalise_smiles.py (GitHub Actions / test, Python 3.9, 3.10, 3.11): Ruff (F401) `os` imported but unused

# Function to canonicalize SMILES
def canonicalize_smiles(smiles):
mol = Chem.MolFromSmiles(smiles) # Convert SMILES to molecule object
if mol: # Check if molecule conversion was successful
return Chem.MolToSmiles(mol, canonical=True) # Return canonical SMILES
else:
return None

# Process a chunk of SMILES
def process_chunk(smiles_chunk):
valid_smiles = []
invalid_smiles = []
for smiles in smiles_chunk:
canonical_smiles = canonicalize_smiles(smiles.strip())
if canonical_smiles:
valid_smiles.append(canonical_smiles)
else:
invalid_smiles.append(smiles.strip())
return valid_smiles, invalid_smiles

# Read SMILES from input file and split them into chunks
def process_smiles_file(input_file, output_file, invalid_file, num_processes=4, chunk_size=100000):
# Get the total number of lines (SMILES strings)
total_lines = sum(1 for _ in open(input_file, 'r'))

Check failure on line 29 in molexpress/pretraining/canonicalise_smiles.py (GitHub Actions / test, Python 3.9, 3.10, 3.11): Ruff (F841) Local variable `total_lines` is assigned to but never used

# Use multiprocessing to process the file in parallel
with open(input_file, 'r') as infile:
smiles_list = infile.readlines()

# Split the smiles into chunks
smiles_chunks = [smiles_list[i:i + chunk_size] for i in range(0, len(smiles_list), chunk_size)]

# Set up a multiprocessing pool
with mp.Pool(processes=num_processes) as pool:
# Process each chunk in parallel
results = list(tqdm(pool.imap(process_chunk, smiles_chunks), total=len(smiles_chunks)))

# Gather results
valid_smiles = []
invalid_smiles = []
for valid, invalid in results:
valid_smiles.extend(valid)
invalid_smiles.extend(invalid)

# Write the results to the output files
with open(output_file, 'w') as outfile:
outfile.write('\n'.join(valid_smiles) + '\n')
with open(invalid_file, 'w') as invalid_outfile:
invalid_outfile.write('\n'.join(invalid_smiles) + '\n')

# Example usage
input_file = 'filtered_pubchem.txt' # Your input file containing SMILES strings
output_file = 'canon_filtered_pubchem.txt' # Output file for valid canonical SMILES
invalid_file = 'invalid_smiles.txt' # Output file for invalid SMILES

# Adjust num_processes based on your machine's CPU cores, and tune chunk_size based on file size
process_smiles_file(input_file, output_file, invalid_file, num_processes=8, chunk_size=100000)
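
The two Ruff failures above (F401: `os` imported but unused; F841: `total_lines` assigned but never used) suggest a small cleanup: drop the import and feed the line count into tqdm instead of discarding it. A hedged sketch of that fix (not part of this commit):

    # Remove the unused `import os`, then put total_lines to work by
    # deriving the progress-bar total from it (ceil division by chunk_size):
    total_lines = sum(1 for _ in open(input_file, 'r'))
    total_chunks = (total_lines + chunk_size - 1) // chunk_size

    with mp.Pool(processes=num_processes) as pool:
        results = list(tqdm(pool.imap(process_chunk, smiles_chunks), total=total_chunks))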