-
Notifications
You must be signed in to change notification settings - Fork 103
/
lyrics2lofi_dataset.py
46 lines (37 loc) · 1.63 KB
/
lyrics2lofi_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from model.dataset import *
class Lyrics2LofiDataset(Dataset):
    """Map-style dataset pairing lyric embeddings with processed lofi samples.

    Each item combines one row of a memory-mapped embeddings array, the
    recorded length of that embedding, and selected fields of the sample
    that ``process_sample`` produced from the corresponding JSON file.

    NOTE(review): row ``i`` of the embeddings array is paired with
    ``files[i]``, so the ``.npy`` file is presumably written in the same
    order as ``files`` -- confirm against the script that generates it.
    """

    def __init__(self, dataset_folder, files, embeddings_file, embedding_lengths_file):
        """Load every sample listed in ``files`` and open the embeddings array.

        Args:
            dataset_folder: Directory containing the sample JSON files.
            files: File names inside ``dataset_folder``; each name is also
                the lookup key into the embedding-lengths JSON.
            embeddings_file: Path of the embeddings array *without* its
                ``.npy`` suffix (the suffix is appended here).
            embedding_lengths_file: Path of a JSON object that maps a file
                name to its embedding length.

        Raises:
            OSError: If one of the files cannot be opened.
            KeyError: If a file name has no entry in the lengths JSON.
        """
        super().__init__()
        self.samples = []
        self.embedding_lengths = []

        # JSON is UTF-8 by specification; be explicit instead of relying on
        # the platform's locale default. The handle is released right after
        # parsing rather than being held open for the whole loading loop.
        with open(embedding_lengths_file, encoding="utf-8") as lengths_fp:
            lengths_by_file = json.load(lengths_fp)

        for file_name in files:
            with open(f"{dataset_folder}/{file_name}", encoding="utf-8") as sample_fp:
                raw_sample = json.load(sample_fp)
            self.embedding_lengths.append(lengths_by_file[file_name])
            # process_sample is not defined in this file; it is expected to
            # be provided by the module's star import.
            self.samples.append(process_sample(raw_sample))

        # Memory-map the array read-only so it is not loaded into RAM as a
        # whole; individual rows are copied out on access.
        self.embeddings = np.load(f"{embeddings_file}.npy", mmap_mode="r")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the item at ``index`` as a dict.

        The embedding row is copied out of the memory map into an
        independent array. ``chords`` and ``melody_notes`` are converted to
        ``torch`` tensors; all other sample fields are passed through as-is.
        """
        sample = self.samples[index]
        return {
            # Copy so the caller never holds a view onto the read-only mmap.
            "embedding": np.copy(self.embeddings[index]),
            "embedding_length": self.embedding_lengths[index],
            "key": sample["key"],
            "mode": sample["mode"],
            "chords": torch.tensor(sample["chords"]),
            "num_chords": sample["num_chords"],
            "melody_notes": torch.tensor(sample["melody_notes"]),
            "tempo": sample["tempo"],
            "energy": sample["energy"],
            "valence": sample["valence"],
        }