Skip to content

Commit

Permalink
minor corrections for trying to make sampling work.
Browse files Browse the repository at this point in the history
  • Loading branch information
manoskary committed Oct 31, 2023
1 parent ddd1883 commit 33221ca
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
21 changes: 12 additions & 9 deletions graphmuse/samplers/sampling_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import graphmuse.samplers as csamplers
from graphmuse.utils.graph import HeteroScoreGraph
from typing import List


class SubgraphCreationSampler(Sampler):
Expand All @@ -27,13 +26,13 @@ class SubgraphCreationSampler(Sampler):
subgraphs_per_max_size : int
The number of subgraphs to create for each max size.
"""
def __init__(self, data_source, max_subgraph_size=100, drop_last=False, batch_size=64, train_idx=None, subgraphs_per_max_size:int=5):
self.data_source = data_source
def __init__(self, graphs, max_subgraph_size=100, drop_last=False, batch_size=64, train_idx=None, subgraphs_per_max_size:int=5):
self.data_source = graphs
bucket_boundaries = [2*max_subgraph_size, 5*max_subgraph_size, 10*max_subgraph_size, 20*max_subgraph_size]
self.sampling_sizes = np.array([2, 4, 10, 20, 40])*subgraphs_per_max_size
ind_n_len = []
train_idx = train_idx if train_idx is not None else list()
for i, g in enumerate(data_source.graphs):
train_idx = train_idx if train_idx is not None else range(len(graphs))
for i, g in enumerate(graphs):
if i in train_idx:
ind_n_len.append((i, g.x.shape[0]))
self.ind_n_len = ind_n_len
Expand Down Expand Up @@ -102,18 +101,19 @@ class MuseDataloader(DataLoader):
num_workers : int
The number of workers.
"""
def __init__(self, graphs, subgraph_size, subgraphs, num_layers=0, samples_per_node=0, batch_size=1, num_workers=0):
def __init__(self, graphs, subgraph_size, subgraphs, num_layers=0, samples_per_node=3, batch_size=1, num_workers=0):
self.graphs = graphs
self.subgraph_size = subgraph_size
self.subgraphs = subgraphs
self.num_layers = num_layers # This is for a later version with node-wise sampling
self.samples_per_node = samples_per_node
self.onsets = {}
self.onset_count = {}
batch_sampler = SubgraphCreationSampler(self, max_subgraph_size=subgraph_size, drop_last=False, batch_size=batch_size)
super().__init__(batch_sampler=batch_sampler, batch_size=1, collate_fn=self.collate_fn, num_workers=num_workers)
dataset = range(len(graphs))
batch_sampler = SubgraphCreationSampler(graphs, max_subgraph_size=subgraph_size, drop_last=False, batch_size=batch_size)
super().__init__(self, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=self.collate_graph_fn)

def collate_fn(self, batch):
def collate_graph_fn(self, batch):
graphlist = self.graphs[batch]
out = self.sample_from_graphlist(graphlist, self.subgraph_size, self.subgraphs, self.num_layers)
return out
Expand Down Expand Up @@ -147,3 +147,6 @@ def sample_from_graphlist(self, graphlist, subgraph_size, subgraphs, num_layers=

return subgraph_samples

def __getitem__(self, idx):
return self.sample_from_graphlist([self.graphs[idx]], self.subgraph_size, self.subgraphs, self.num_layers)

2 changes: 2 additions & 0 deletions graphmuse/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
from .general import MapDict
import warnings
import graphmuse.samplers as csamplers


class HeteroScoreGraph(object):
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, note_features, edges, etypes=["onset", "consecutive", "during
self.note_array = note_array
self.edge_type = torch.from_numpy(edges[-1]).long()
self.edge_index = torch.from_numpy(edges[:2]).long()
self.c_graph = csamplers.graph(edges.astype(np.uint32))
self.edge_weights = torch.ones(len(self.edge_index[0])) if edge_weights is None else torch.from_numpy(
edge_weights)
self.name = name
Expand Down
4 changes: 2 additions & 2 deletions tests/test_musical_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


num_graphs = 100
num_graphs = 10
max_nodes = 5000
min_nodes = 500
max_dur = 20
Expand Down Expand Up @@ -37,7 +37,7 @@
graphs.append(graph)

# create dataloader
dataloader = MuseDataloader(graphs, subgraph_size=100, subgraphs=10, batch_size=1, num_workers=0)
dataloader = MuseDataloader(graphs, subgraph_size=100, subgraphs=4, batch_size=1, num_workers=1)
# iterate over dataloader
batch = next(iter(dataloader))

Expand Down

0 comments on commit 33221ca

Please sign in to comment.