-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
408 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Attentive Chrome Kipoi | ||
|
||
## Dependency Requirements | ||
* python>=3.5 | ||
* numpy | ||
* pytorch-cpu | ||
* torchvision-cpu | ||
|
||
## Quick Start | ||
### Creating a new conda environment using kipoi | ||
`kipoi env create AttentiveChrome` | ||
|
||
### Activating environment | ||
`conda activate kipoi-AttentiveChrome` | ||
|
||
## Command Line | ||
### Getting example input file | ||
Replace {model_name} with the actual name of the model (e.g. E003, E005, etc.). | ||
|
||
`kipoi get-example AttentiveChrome/{model_name} -o example_file` | ||
|
||
example: `kipoi get-example AttentiveChrome/E003 -o example_file` | ||
|
||
### Predicting using example file | ||
`kipoi predict AttentiveChrome/{model_name} --dataloader_args='{"input_file": "example_file/input_file", "bin_size": 100}' -o example_predict.tsv` | ||
|
||
This should produce a tsv file containing the results. | ||
|
||
## Python | ||
### Fetching the model | ||
First, import kipoi: | ||
`import kipoi` | ||
|
||
Next, get the model. Replace {model_name} with the actual name of the model (e.g. E003, E005, etc.). | ||
|
||
`model = kipoi.get_model("AttentiveChrome/{model_name}")` | ||
|
||
### Predicting using pipeline | ||
`prediction = model.pipeline.predict({"input_file": "path to input file", "bin_size": {some integer}})` | ||
|
||
This returns a numpy array containing the output from the final sigmoid function. | ||
|
||
e.g. `model.pipeline.predict({"input_file": "data/input_file", "bin_size": 100})` | ||
|
||
### Predicting for a single batch | ||
First, we need to set up our dataloader `dl`. | ||
|
||
`dl = model.default_dataloader(input_file="path to input file", bin_size={some integer})` | ||
|
||
Next, we can use the iterator functionality of the dataloader. | ||
|
||
`it = dl.batch_iter(batch_size=32)` | ||
|
||
`single_batch = next(it)` | ||
|
||
The first line gives us an iterator named `it`, where each batch contains 32 items. We can then use `next(it)` to get a batch. | ||
|
||
Then, we can perform prediction on this single batch. | ||
|
||
`prediction = model.predict_on_batch(single_batch['inputs'])` | ||
|
||
This also returns a numpy array containing the output from the final sigmoid function. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import torch | ||
import collections | ||
import pdb | ||
import csv | ||
from kipoi.data import Dataset | ||
import math | ||
import numpy as np | ||
|
||
class HMData(Dataset):
    """Dataset serving one histone-modification bin matrix per gene.

    The input is a CSV file in which each gene occupies ``bin_size``
    consecutive rows. Within a row, column 0 holds an identifier whose part
    before the first underscore is used as the gene ID, columns 2-6 hold the
    five integer histone-modification values, and column 7 holds the integer
    expression label. Column 1 is not read.
    """

    # Zero-based CSV columns of the five histone-modification values, in the
    # order they are stored under the keys 'hm1' .. 'hm5'.
    _HM_COLUMNS = (2, 3, 4, 5, 6)
    # Zero-based CSV column of the thresholded expression label.
    _LABEL_COLUMN = 7

    def __init__(self, input_file, bin_size=100):
        """Load ``input_file`` eagerly.

        Args:
            input_file: Path of the CSV file to read.
            bin_size: Number of consecutive rows that belong to one gene.
        """
        self.hm_data = self.loadData(input_file, bin_size)

    def loadData(self, filename, windows):
        """Parse the CSV file into an ordered mapping of per-gene records.

        Args:
            filename: Path of the CSV file to read.
            windows: Number of consecutive rows that belong to one gene.

        Returns:
            An ``OrderedDict`` keyed by the running gene index (0, 1, ...).
            Each value is a dict with the keys ``'geneID'`` (str), ``'expr'``
            (int) and ``'hm1'`` .. ``'hm5'`` (tensors of shape
            ``(windows, 1)``).

        Raises:
            IndexError: If the file is empty, a row has too few columns, or
                the number of rows is not a multiple of ``windows``.
            ValueError: If a value that should be an integer cannot be parsed.
        """
        # newline="" is what the csv module expects for files it reads; the
        # ``with`` block closes the file, so no explicit close() is needed.
        with open(filename, newline="") as fi:
            data = list(csv.reader(fi))

        ncols = len(data[0])
        nrows = len(data)
        ngenes = nrows // windows
        nfeatures = ncols - 1
        print("Number of genes: %d" % ngenes)
        print("Number of entries: %d" % nrows)
        print("Number of HMs: %d" % nfeatures)

        dtype = torch.get_default_dtype()
        attr = collections.OrderedDict()

        for count, start in enumerate(range(0, nrows, windows)):
            # Index row by row (rather than slicing) so that a truncated last
            # gene raises IndexError instead of silently yielding a short
            # matrix.
            window_rows = [data[start + w] for w in range(windows)]

            record = {
                'geneID': str(window_rows[0][0].split("_")[0]),
                # The label is taken from the last row of the gene's window.
                'expr': int(window_rows[-1][self._LABEL_COLUMN]),
            }
            for hm_index, column in enumerate(self._HM_COLUMNS, start=1):
                record['hm%d' % hm_index] = torch.tensor(
                    [[int(row[column])] for row in window_rows],
                    dtype=dtype,
                )
            attr[count] = record

        return attr

    def __len__(self):
        """Return the number of genes that were loaded."""
        return len(self.hm_data)

    def __getitem__(self, i):
        """Return the sample for gene index ``i``.

        Returns:
            A dict with ``'inputs'`` (numpy array of shape ``(bin_size, 5)``
            with the five histone-modification columns side by side) and
            ``'metadata'`` (dict with ``'geneID'`` and ``'label'``).
        """
        record = self.hm_data[i]
        columns = tuple(
            record['hm%d' % k] for k in range(1, len(self._HM_COLUMNS) + 1)
        )
        final_data = torch.cat(columns, 1).numpy()

        return {
            'inputs': final_data,
            'metadata': {'geneID': record['geneID'], 'label': record['expr']},
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
defined_as: dataloader.HMData | ||
args: | ||
input_file: | ||
doc: "Path of the histone modification read count file." | ||
example: | ||
url: https://zenodo.org/record/2640883/files/test.csv?download=1 | ||
md5: 0468f46aa1a3864283e87c7714d0a4e2 | ||
bin_size: | ||
doc: "Number of bins (consecutive rows) per gene in the input file" | ||
optional: true | ||
dependencies: | ||
conda: # install via conda | ||
- python>=3.5 | ||
- pytorch::pytorch-cpu | ||
- numpy | ||
info: # General information about the dataloader | ||
authors: | ||
- name: Ritambhara Singh | ||
github: rs3zz | ||
- name: Jack Lanchantin | ||
github: jacklanchantin | ||
email: [email protected] | ||
- name: Arshdeep Sekhon | ||
github: ArshdeepSekhon | ||
- name: Yanjun Qi | ||
github: qiyanjun | ||
contributors: | ||
- name: Jack Lanchantin | ||
github: jacklanchantin | ||
- name: Jeffrey Yoo | ||
github: jeffreyyoo | ||
doc: "Dataloader for Gene Expression Prediction" | ||
cite_as: https://doi.org/10.1101/329334 | ||
trained_on: Histone Modification and RNA-Seq Data from Roadmap/REMC database # short dataset description | ||
license: MIT | ||
output_schema: | ||
inputs: | ||
associated_metadata: geneID, label | ||
doc: Histone Modification Bin Matrix | ||
shape: (100, 5) # array shape of a single sample (omitting the batch dimension) | ||
metadata: | ||
geneID: | ||
doc: "gene ID" | ||
type: str | ||
label: | ||
doc: "label for gene expression (binary)" | ||
type: int | ||
type: Dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
type: pytorch | ||
args: | ||
module_file: models.py | ||
module_obj: att_chrome_model | ||
weights: | ||
url: {{model_url}} | ||
md5: {{model_md5}} | ||
default_dataloader: .. # path to the dataloader directory. Or to the dataloader class, e.g.: `kipoiseq.dataloaders.SeqIntervalDl` | ||
|
||
info: # General information about the model | ||
authors: | ||
- name: Ritambhara Singh | ||
github: rs3zz | ||
- name: Jack Lanchantin | ||
github: jacklanchantin | ||
email: [email protected] | ||
- name: Arshdeep Sekhon | ||
github: ArshdeepSekhon | ||
- name: Yanjun Qi | ||
github: qiyanjun | ||
contributors: | ||
- name: Jack Lanchantin | ||
github: jacklanchantin | ||
- name: Jeffrey Yoo | ||
github: jeffreyyoo | ||
doc: Gene Expression Prediction | ||
cite_as: https://doi.org/10.1101/329334 | ||
trained_on: Histone Modification and RNA-Seq Data from Roadmap/REMC database # short dataset description | ||
license: MIT | ||
dependencies: | ||
conda: # install via conda | ||
- python>=3.5 | ||
- numpy | ||
- pytorch::pytorch-cpu | ||
- pytorch::torchvision-cpu | ||
schema: # Model schema | ||
inputs: | ||
shape: (100, 5) # array shape of a single sample (omitting the batch dimension) | ||
doc: "Histone Modification Bin Matrix" | ||
targets: | ||
shape: (1, ) | ||
doc: "Binary Classification" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from __future__ import print_function | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from pdb import set_trace as stop | ||
|
||
def batch_product(iput, mat2):
    """Multiply every matrix of a batch by the same right-hand matrix.

    Args:
        iput: Tensor of shape ``(batch, rows, inner)``.
        mat2: Tensor of shape ``(inner, cols)`` shared by all batch entries.

    Returns:
        Tensor of shape ``(batch, rows)`` when ``cols == 1`` (the trailing
        singleton dimension is squeezed away), otherwise
        ``(batch, rows, cols)``.
    """
    # A single broadcasted matmul replaces the per-sample torch.mm calls and
    # the repeated torch.cat, and also works for an empty batch, where the
    # loop-based version had nothing to concatenate.
    return torch.matmul(iput, mat2).squeeze(2)
|
||
|
||
class rec_attention(nn.Module):
    """Attention over a sequence using one learned context vector.

    Each position of the input sequence is scored against the context
    vector, the scores are normalised with a softmax over the sequence
    dimension, and the input is summed with those weights.
    """

    def __init__(self, hm, args):
        """Create the context vector.

        Args:
            hm: When equal to ``False`` the representation size is
                ``args.bin_rnn_size`` times the number of directions;
                otherwise it is ``args.bin_rnn_size``.
            args: Object exposing ``bidirectional`` and ``bin_rnn_size``.
        """
        super(rec_attention, self).__init__()
        if args.bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        # The equality test (rather than ``not hm``) is kept on purpose so
        # the branch taken for any given ``hm`` value stays the same.
        multiplier = self.num_directions if hm == False else 1  # noqa: E712
        self.bin_rep_size = args.bin_rnn_size * multiplier

        context = torch.Tensor(self.bin_rep_size, 1)
        self.bin_context_vector = nn.Parameter(context, requires_grad=True)
        self.softmax = nn.Softmax(dim=1)
        # Initialise the (so far uninitialised) context vector in place.
        self.bin_context_vector.data.uniform_(-0.1, 0.1)

    def forward(self, iput):
        """Apply attention to ``iput`` of shape ``(batch, length, rep)``.

        Returns:
            A pair ``(repres, alpha)``: the attention-weighted sum of shape
            ``(batch, 1, rep)`` and the attention weights of shape
            ``(batch, length)``.
        """
        scores = batch_product(iput, self.bin_context_vector)
        alpha = self.softmax(scores)
        batch_size, source_length, _ = iput.size()
        # Turn the weights into one row vector per batch entry for bmm.
        weights = alpha.unsqueeze(2).view(batch_size, -1, source_length)
        repres = torch.bmm(weights, iput)
        return repres, alpha
|
||
|
||
|
||
class recurrent_encoder(nn.Module):
    """LSTM encoder followed by attention over its outputs."""

    def __init__(self, n_bins, ip_bin_size, hm, args):
        """Build the LSTM and its attention layer.

        Args:
            n_bins: Stored as ``seq_length``; not otherwise used here.
            ip_bin_size: Input feature size of the LSTM.
            hm: When equal to ``False`` the LSTM hidden size is
                ``args.bin_rnn_size``; otherwise it is half of that
                (integer division).
            args: Object exposing ``bin_rnn_size``, ``bidirectional``,
                ``num_layers`` and ``dropout``.
        """
        super(recurrent_encoder, self).__init__()
        self.ipsize = ip_bin_size
        self.seq_length = n_bins
        self.num_directions = 2 if args.bidirectional else 1

        # Equality test kept as-is so every ``hm`` value picks the same branch.
        if hm == False:  # noqa: E712
            hidden_size = args.bin_rnn_size
        else:
            hidden_size = args.bin_rnn_size // 2
        self.bin_rnn_size = hidden_size
        self.bin_rep_size = hidden_size * self.num_directions

        self.rnn = nn.LSTM(
            self.ipsize,
            self.bin_rnn_size,
            num_layers=args.num_layers,
            dropout=args.dropout,
            bidirectional=args.bidirectional,
        )
        self.bin_attention = rec_attention(hm, args)

    def outputlength(self):
        """Return the size of the representation produced per position."""
        return self.bin_rep_size

    def forward(self, single_hm, hidden=None):
        """Encode ``single_hm`` and attend over the LSTM outputs.

        Returns:
            The pair produced by the attention layer: the attended
            representation and the attention weights.
        """
        rnn_out, _ = self.rnn(single_hm, hidden)
        # Swap the first two axes before handing the outputs to attention.
        swapped = rnn_out.permute(1, 0, 2)
        hm_rep, bin_alpha = self.bin_attention(swapped)
        return hm_rep, bin_alpha
|
||
|
||
class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict``."""
        dict.__init__(self, *args, **kwargs)
        # Using the dict itself as the instance namespace makes ``d.key``
        # and ``d['key']`` refer to the same storage.
        self.__dict__ = self
|
||
|
||
class att_chrome(nn.Module):
    """Two-level attention model over histone-modification bins.

    One recurrent encoder per histone modification summarises its bins; a
    second encoder attends over those summaries, and a linear layer followed
    by a sigmoid produces the prediction.
    """

    def __init__(self, args):
        """Build the per-modification encoders and the top-level encoder.

        Args:
            args: Object exposing ``n_hms`` and ``n_bins`` plus the settings
                consumed by ``recurrent_encoder``.
        """
        super(att_chrome, self).__init__()
        self.n_hms = args.n_hms
        self.n_bins = args.n_bins
        self.ip_bin_size = 1

        encoders = [
            recurrent_encoder(self.n_bins, self.ip_bin_size, False, args)
            for _ in range(self.n_hms)
        ]
        self.rnn_hms = nn.ModuleList(encoders)
        self.opsize = self.rnn_hms[0].outputlength()
        self.hm_level_rnn_1 = recurrent_encoder(
            self.n_hms, self.opsize, True, args
        )
        self.opsize2 = self.hm_level_rnn_1.outputlength()
        self.diffopsize = 2 * self.opsize2
        self.fdiff1_1 = nn.Linear(self.opsize2, 1)

    def forward(self, iput):
        """Run the model on ``iput``, a 3-D tensor indexed by histone
        modification along its last axis.

        Returns:
            The sigmoid of the final linear layer's output.
        """
        # The three-way unpack fails early when the input is not 3-D.
        batch_size, _, _ = iput.size()

        reps = []
        alphas = []
        for hm_index, encoder in enumerate(self.rnn_hms):
            column = iput[:, :, hm_index].contiguous()
            # Transpose, then add a trailing feature axis of size 1.
            column = torch.t(column).unsqueeze(2)
            rep, alpha = encoder(column)
            reps.append(rep)
            alphas.append(alpha)

        level1_rep = torch.cat(reps, 1)
        # Bin-level attention weights are gathered but not returned.
        bin_a = torch.cat(alphas, 1)

        level1_rep = level1_rep.permute(1, 0, 2)
        final_rep_1, hm_level_attention_1 = self.hm_level_rnn_1(level1_rep)
        logits = self.fdiff1_1(final_rep_1.squeeze(1))
        return torch.sigmoid(logits)
|
||
# Settings used to instantiate the model object.
args_dict = {
    'lr': 0.0001,
    'model_name': 'attchrome',
    'clip': 1,
    'epochs': 2,
    'batch_size': 10,
    'dropout': 0.5,
    'cell_1': 'Cell1',
    'save_root': 'Results/Cell1',
    'data_root': 'data/',
    'gpuid': 0,
    'gpu': 0,
    'n_hms': 5,
    'n_bins': 200,
    'bin_rnn_size': 32,
    'num_layers': 1,
    'unidirectional': False,
    'save_attention_maps': False,
    'attentionfilename': 'beta_attention.txt',
    'test_on_saved_model': False,
    'bidirectional': True,
    'dataset': 'Cell1',
}
# Wrap the settings so that they can be read as attributes (args.n_hms, ...).
att_chrome_args = AttrDict(args_dict)
# Module-level model instance built from the settings above.
att_chrome_model = att_chrome(att_chrome_args)
|
Oops, something went wrong.