diff --git a/graphmuse/samplers/__init__.py b/graphmuse/samplers/__init__.py index c4f9b98..21ae6fb 100644 --- a/graphmuse/samplers/__init__.py +++ b/graphmuse/samplers/__init__.py @@ -184,17 +184,20 @@ def sample_neighbors_in_score_graph(note_array, depth, samples_per_node, targets samples_per_layer: PyList(type=np.ndarray, length=depth+1) List of numpy arrays of nodes (called layers) where the last layer corresponds to 'targets' and each n-th layer which isn't the last is a subset of the pre-neighborhood of the n+1-th layer edges_between_layers: PyList(type=np.ndarray(2, N), length=depth) - List of numpy arrays of edges which show how 2 consecutive layers in samples_per_layer are connected + List of numpy arrays of edges which show how 2 consecutive layers in samples_per_layer are connected + total_samples : numpy.ndarray + the union of samples_per_layer """ assert len(targets)>0 onsets = note_array["onset_div"].astype(numpy.int32) durations = note_array["duration_div"].astype(numpy.int32) - samples_per_layer, edges_between_layers = c_sample_neighbors_in_score_graph(onsets, durations, depth, samples_per_node, targets) + samples_per_layer, edges_between_layers, total_samples = c_sample_neighbors_in_score_graph(onsets, durations, depth, samples_per_node, targets) # move to torch tensors samples_per_layer = [torch.from_numpy(layer) for layer in samples_per_layer] edges_between_layers = [torch.from_numpy(edges) for edges in edges_between_layers] - return samples_per_layer, edges_between_layers + total_samples = torch.from_numpy(total_samples) + return samples_per_layer, edges_between_layers, total_samples def sample_preneighbors_within_region(cgraph, region, samples_per_node=10): diff --git a/graphmuse/samplers/sampling_sketch.py b/graphmuse/samplers/sampling_sketch.py index 7b15e14..dd727ff 100644 --- a/graphmuse/samplers/sampling_sketch.py +++ b/graphmuse/samplers/sampling_sketch.py @@ -149,7 +149,7 @@ def sample_from_graphlist(self, graphlist): if self.sample_rightmost: # Sample rightmost node-wise by num layers (because of reverse edges missing) - right_layers, edges_between_right_layers = csamplers.sample_neighbors_in_score_graph(random_graph.note_array, self.num_layers-2, self.samples_per_node, right_extension) + right_layers, edges_between_right_layers, total_right_samples = csamplers.sample_neighbors_in_score_graph(random_graph.note_array, self.num_layers-2, self.samples_per_node, right_extension) else: right_layers, edges_between_right_layers = [], [] edges_between_layers = torch.cat((left_edges, right_edges, edges_between_right_layers, edges_between_left_layers), dim=1) @@ -158,8 +158,7 @@ def sample_from_graphlist(self, graphlist): # Translate edges to subgraph indices (do this on GPU when available). subgraph_edge_index = torch.cat((edges_within_region, edges_between_layers), dim=1).to(self.device) - # TODO: add total right samples - sampled_nodes = torch.cat((torch.arange(region[0], region[1]), sampled_nodes_first_layer, total_left_samples)).to(self.device) + sampled_nodes = torch.cat((torch.arange(region[0], region[1]), sampled_nodes_first_layer, total_left_samples, total_right_samples)).to(self.device) new_mapping = torch.arange(sampled_nodes.shape[0], device=self.device) nodes_remap = torch.empty_like(random_graph.x.shape[0]).to(self.device) nodes_remap[sampled_nodes] = new_mapping diff --git a/include/musical_sampling.c b/include/musical_sampling.c index d567b90..0fcce08 100644 --- a/include/musical_sampling.c +++ b/include/musical_sampling.c @@ -349,6 +349,10 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject* HashSet node_hash_set; HashSet_new(&node_hash_set, (Index)PyArray_SIZE(prev_layer)); + HashSet total_samples; + HashSet_new(&total_samples, (Index)PyArray_SIZE(prev_layer)); + HashSet_init(&total_samples); + HashSet node_tracker; HashSet_new(&node_tracker, samples_per_node); @@ -406,6 +410,7 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject* if(neighbor_count <= samples_per_node){ for(Index j=lower_bound; j