Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove label dtype/depend on BED row format #48

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions kipoiseq/dataloaders/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class BedDataset(object):
bed_columns: number of columns corresponding to the bed file. All the columns
after that will be parsed as targets
num_chr: if specified, 'chr' in the chromosome name will be dropped
label_dtype: specific data type for labels, Example: `float` or `np.float32`
ambiguous_mask: if specified, rows containing only ambiguous_mask values will be skipped
incl_chromosomes: exclusive list of chromosome names to include in the final dataset.
if not None, only these will be present in the dataset
Expand All @@ -65,7 +64,6 @@ class BedDataset(object):
int] # blockStarts

def __init__(self, tsv_file,
label_dtype=None,
bed_columns=3,
num_chr=False,
ambiguous_mask=None,
Expand All @@ -76,7 +74,6 @@ def __init__(self, tsv_file,
self.tsv_file = tsv_file
self.bed_columns = bed_columns
self.num_chr = num_chr
self.label_dtype = label_dtype
self.ambiguous_mask = ambiguous_mask
self.incl_chromosomes = incl_chromosomes
self.excl_chromosomes = excl_chromosomes
Expand All @@ -95,8 +92,7 @@ def __init__(self, tsv_file,
self.df = pd.read_table(self.tsv_file,
header=None,
dtype={i: d
for i, d in enumerate(self.bed_types[:self.bed_columns] +
[self.label_dtype] * self.n_tasks)},
for i, d in enumerate(self.bed_types[:found_columns])},
sep='\t')
if self.num_chr and self.df.iloc[0][0].startswith("chr"):
self.df[0] = self.df[0].str.replace("^chr", "")
Expand All @@ -122,14 +118,14 @@ def __getitem__(self, idx):
if self.ignore_targets or self.n_tasks == 0:
labels = {}
else:
labels = row.iloc[self.bed_columns:].values.astype(self.label_dtype)
labels = row.iloc[self.bed_columns:].values
return interval, labels

def __len__(self):
return len(self.df)

def get_targets(self):
return self.df.iloc[:, self.bed_columns:].values.astype(self.label_dtype)
return self.df.iloc[:, self.bed_columns:].values


@kipoi_dataloader(override={"dependencies": deps, 'info.authors': package_authors})
Expand All @@ -153,8 +149,6 @@ class StringSeqIntervalDl(Dataset):
md5: 01320157a250a3d2eea63e89ecf79eba
num_chr_fasta:
doc: True, the the dataloader will make sure that the chromosomes don't start with chr.
label_dtype:
doc: None, datatype of the task labels taken from the intervals_file. Example - str, int, float, np.float32
auto_resize_len:
doc: None, required sequence length.
# max_seq_len:
Expand Down Expand Up @@ -189,7 +183,6 @@ def __init__(self,
intervals_file,
fasta_file,
num_chr_fasta=False,
label_dtype=None,
auto_resize_len=None,
# max_seq_len=None,
# use_strand=False,
Expand All @@ -213,7 +206,6 @@ def __init__(self,
self.bed = BedDataset(self.intervals_file,
num_chr=self.num_chr_fasta,
bed_columns=3,
label_dtype=parse_dtype(label_dtype),
ignore_targets=ignore_targets)
self.fasta_extractors = None

Expand Down Expand Up @@ -281,8 +273,6 @@ class SeqIntervalDl(Dataset):
md5: 01320157a250a3d2eea63e89ecf79eba
num_chr_fasta:
doc: True, the the dataloader will make sure that the chromosomes don't start with chr.
label_dtype:
doc: 'None, datatype of the task labels taken from the intervals_file. Example: str, int, float, np.float32'
auto_resize_len:
doc: None, required sequence length.
# use_strand:
Expand Down Expand Up @@ -324,7 +314,6 @@ def __init__(self,
intervals_file,
fasta_file,
num_chr_fasta=False,
label_dtype=None,
auto_resize_len=None,
# max_seq_len=None,
# use_strand=False,
Expand All @@ -335,7 +324,7 @@ def __init__(self,
dtype=None):
# core dataset, not using the one-hot encoding params
self.seq_dl = StringSeqIntervalDl(intervals_file, fasta_file, num_chr_fasta=num_chr_fasta,
label_dtype=label_dtype, auto_resize_len=auto_resize_len,
auto_resize_len=auto_resize_len,
# use_strand=use_strand,
ignore_targets=ignore_targets)

Expand Down