diff --git a/graphmuse/samplers/sampling_sketch.py b/graphmuse/samplers/sampling_sketch.py index 919ba03..ce12c99 100644 --- a/graphmuse/samplers/sampling_sketch.py +++ b/graphmuse/samplers/sampling_sketch.py @@ -3,7 +3,6 @@ import torch import graphmuse.samplers as csamplers from graphmuse.utils.graph import HeteroScoreGraph -from typing import List class SubgraphCreationSampler(Sampler): @@ -27,13 +26,13 @@ class SubgraphCreationSampler(Sampler): subgraphs_per_max_size : int The number of subgraphs to create for each max size. """ - def __init__(self, data_source, max_subgraph_size=100, drop_last=False, batch_size=64, train_idx=None, subgraphs_per_max_size:int=5): - self.data_source = data_source + def __init__(self, graphs, max_subgraph_size=100, drop_last=False, batch_size=64, train_idx=None, subgraphs_per_max_size:int=5): + self.data_source = graphs bucket_boundaries = [2*max_subgraph_size, 5*max_subgraph_size, 10*max_subgraph_size, 20*max_subgraph_size] self.sampling_sizes = np.array([2, 4, 10, 20, 40])*subgraphs_per_max_size ind_n_len = [] - train_idx = train_idx if train_idx is not None else list() - for i, g in enumerate(data_source.graphs): + train_idx = train_idx if train_idx is not None else range(len(graphs)) + for i, g in enumerate(graphs): if i in train_idx: ind_n_len.append((i, g.x.shape[0])) self.ind_n_len = ind_n_len @@ -102,7 +101,7 @@ class MuseDataloader(DataLoader): num_workers : int The number of workers. """ - def __init__(self, graphs, subgraph_size, subgraphs, num_layers=0, samples_per_node=0, batch_size=1, num_workers=0): + def __init__(self, graphs, subgraph_size, subgraphs, num_layers=0, samples_per_node=3, batch_size=1, num_workers=0): self.graphs = graphs self.subgraph_size = subgraph_size self.subgraphs = subgraphs @@ -110,10 +109,11 @@ def __init__(self, graphs, subgraph_size, subgraphs, num_layers=0, samples_per_n self.samples_per_node = samples_per_node self.onsets = {} self.onset_count = {} - batch_sampler = SubgraphCreationSampler(self, max_subgraph_size=subgraph_size, drop_last=False, batch_size=batch_size) - super().__init__(batch_sampler=batch_sampler, batch_size=1, collate_fn=self.collate_fn, num_workers=num_workers) + dataset = range(len(graphs)) + batch_sampler = SubgraphCreationSampler(graphs, max_subgraph_size=subgraph_size, drop_last=False, batch_size=batch_size) + super().__init__(self, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=self.collate_graph_fn) - def collate_fn(self, batch): + def collate_graph_fn(self, batch): graphlist = self.graphs[batch] out = self.sample_from_graphlist(graphlist, self.subgraph_size, self.subgraphs, self.num_layers) return out @@ -147,3 +147,6 @@ def sample_from_graphlist(self, graphlist, subgraph_size, subgraphs, num_layers= return subgraph_samples + def __getitem__(self, idx): + return self.sample_from_graphlist([self.graphs[idx]], self.subgraph_size, self.subgraphs, self.num_layers) + diff --git a/graphmuse/utils/graph.py b/graphmuse/utils/graph.py index 99271ee..cdaf1bf 100644 --- a/graphmuse/utils/graph.py +++ b/graphmuse/utils/graph.py @@ -8,6 +8,7 @@ import pickle from .general import MapDict import warnings +import graphmuse.samplers as csamplers class HeteroScoreGraph(object): @@ -49,6 +50,7 @@ def __init__(self, note_features, edges, etypes=["onset", "consecutive", "during self.note_array = note_array self.edge_type = torch.from_numpy(edges[-1]).long() self.edge_index = torch.from_numpy(edges[:2]).long() + self.c_graph = csamplers.graph(edges.astype(np.uint32)) self.edge_weights = torch.ones(len(self.edge_index[0])) if edge_weights is None else torch.from_numpy( edge_weights) self.name = name diff --git a/tests/test_musical_sampling.py b/tests/test_musical_sampling.py index 4f87548..6de27d8 100644 --- a/tests/test_musical_sampling.py +++ b/tests/test_musical_sampling.py @@ -3,7 +3,7 @@ import numpy as np -num_graphs = 100 +num_graphs = 10 max_nodes = 5000 min_nodes = 500 max_dur = 20 @@ -37,7 +37,7 @@ graphs.append(graph) # create dataloader -dataloader = MuseDataloader(graphs, subgraph_size=100, subgraphs=10, batch_size=1, num_workers=0) +dataloader = MuseDataloader(graphs, subgraph_size=100, subgraphs=4, batch_size=1, num_workers=1) # iterate over dataloader batch = next(iter(dataloader))