diff --git a/BPNet-OSKN/model.py b/BPNet-OSKN/model.py new file mode 100644 index 000000000..5b31cbdfc --- /dev/null +++ b/BPNet-OSKN/model.py @@ -0,0 +1,75 @@ +from bpnet.seqmodel import SeqModel +from keras.models import load_model +import numpy as np +import bpnet +import tensorflow as tf +from bpnet.functions import softmax +import keras.backend as K +import keras.layers as kl +from kipoi.model import BaseModel + +def profile_contrib(p): + return kl.Lambda(lambda p: + K.mean(K.sum(K.stop_gradient(tf.nn.softmax(p, dim=-2)) * p, axis=-2), axis=-1) + )(p) + + +class BPNetOldSeqModel(BaseModel, SeqModel): + + preact_tensor_names = ['reshape_2/Reshape:0', + 'dense_1/BiasAdd:0', + 'reshape_4/Reshape:0', + 'dense_3/BiasAdd:0', + 'reshape_6/Reshape:0', + 'dense_5/BiasAdd:0', + 'reshape_8/Reshape:0', + 'dense_7/BiasAdd:0' + ] + + bottleneck_name = 'add_9/add:0' + + target_names = ['Oct4/profile', + 'Oct4/counts', + 'Sox2/profile', + 'Sox2/counts', + 'Nanog/profile', + 'Nanog/counts', + 'Klf4/profile', + 'Klf4/counts'] + + seqlen = 1000 + + tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4'] + + postproc_fns = [softmax, None] * 4 + + def __init__(self, model_file): + self.model_file = model_file + K.clear_session() # restart session + self.model = load_model(model_file, compile=False) + self.contrib_fns = {} + + def predict_on_batch(self, seq): + preds = self.model.predict_on_batch({"seq": seq, **self.neutral_bias_inputs(len(seq), seqlen=seq.shape[1])}) + pred_dict = {target: preds[i] for i, target in enumerate(self.target_names)} + return {task: softmax(pred_dict[f'{task}/profile']) * np.exp(pred_dict[f'{task}/counts'][:, np.newaxis]) + for task in self.tasks} + + def neutral_bias_inputs(self, length, seqlen): + """Compile a set of neutral bias inputs + """ + return dict([('bias/' + target, np.zeros((length, seqlen, 4)) + if target.endswith("/profile") + else np.zeros((length, 2))) + for target in self.target_names]) + + def get_intp_tensors(self, preact_only=True, graph=None): + if graph is None: + graph = tf.get_default_graph() + intp_targets = [] + for head_name, tensor_name in zip(self.target_names, self.preact_tensor_names): + tensor = graph.get_tensor_by_name(tensor_name) + if head_name.endswith("/profile"): + tensor = profile_contrib(tensor) + intp_targets.append((head_name, tensor)) + return intp_targets diff --git a/BPNet-OSKN/model.yaml b/BPNet-OSKN/model.yaml new file mode 100644 index 000000000..d6bf2141e --- /dev/null +++ b/BPNet-OSKN/model.yaml @@ -0,0 +1,67 @@ +defined_as: model.BPNetOldSeqModel +args: + model_file: + # TODO - put to Zenodo + url: 'http://mitra.stanford.edu/kundaje/avsec/chipnexus/paper/modisco-comparison/v2-output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE/model.calibrated.h5' + md5: bbe883baef261877bfad07d05feb627d + +default_dataloader: + defined_as: kipoiseq.dataloaders.SeqIntervalDl + default_args: + auto_resize_len: 1000 + ignore_targets: True + +info: + authors: + - name: Ziga Avsec + github: avsecz + doc: BPNet model predicting the ChIP-nexus profiles of Oct4, Sox2, Nanog and Klf4 + cite_as: TODO + trained_on: ChIP-nexus data in mm10. test chromosomes 1, 8, 9, validation chromosomes 2, 3, 4 + license: MIT + +dependencies: + channels: + - bioconda + - pytorch + - conda-forge + - defaults + conda: + - python=3.6 + - bioconda::pybedtools>=0.7.10 + - bioconda::bedtools>=2.27.1 + - bioconda::pybigwig>=0.3.10 + - bioconda::pysam>=0.14.0 + - bioconda::genomelake>=0.1.4 + + - pytorch::pytorch # optional for data-loading + - cython + - h5py>=2.7.0 + - numpy + + - pandas>=0.23.0 + - fastparquet + - python-snappy + + - nb_conda + pip: + - tensorflow>=1.0 + - git+https://github.com/kundajelab/DeepExplain.git + - bpnet[extras] +schema: + inputs: + shape: (1000, 4) + doc: "One-hot encoded DNA sequence." + targets: + Oct4: + shape: (1000,2) + doc: "Strand-specific ChIP-nexus data for Oct4." + Sox2: + shape: (1000,2) + doc: "Strand-specific ChIP-nexus data for Sox2." + Nanog: + shape: (1000,2) + doc: "Strand-specific ChIP-nexus data for Nanog." + Klf4: + shape: (1000,2) + doc: "Strand-specific ChIP-nexus data for Klf4."