Skip to content

Commit

Permalink
added total_samples to return of sample_neighbors_in_scoregraph
Browse files Browse the repository at this point in the history
  • Loading branch information
nimrodVarga committed Dec 11, 2023
1 parent fa30e9d commit fb532a7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
9 changes: 6 additions & 3 deletions graphmuse/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,20 @@ def sample_neighbors_in_score_graph(note_array, depth, samples_per_node, targets
samples_per_layer: PyList(type=np.ndarray, length=depth+1)
List of numpy arrays of nodes (called layers) where the last layer corresponds to 'targets' and each n-th layer which isn't the last is a subset of the pre-neighborhood of the n+1-th layer
edges_between_layers: PyList(type=np.ndarray(2, N), length=depth)
List of numpy arrays of edges which show how 2 consecutive layers in samples_per_layer are connected
List of numpy arrays of edges which show how 2 consecutive layers in samples_per_layer are connected
total_samples : numpy.ndarray
the union of samples_per_layer
"""
assert len(targets)>0

onsets = note_array["onset_div"].astype(numpy.int32)
durations = note_array["duration_div"].astype(numpy.int32)
samples_per_layer, edges_between_layers = c_sample_neighbors_in_score_graph(onsets, durations, depth, samples_per_node, targets)
samples_per_layer, edges_between_layers, total_samples = c_sample_neighbors_in_score_graph(onsets, durations, depth, samples_per_node, targets)
# move to torch tensors
samples_per_layer = [torch.from_numpy(layer) for layer in samples_per_layer]
edges_between_layers = [torch.from_numpy(edges) for edges in edges_between_layers]
return samples_per_layer, edges_between_layers
total_samples = torch.from_numpy(total_samples)
return samples_per_layer, edges_between_layers, total_samples


def sample_preneighbors_within_region(cgraph, region, samples_per_node=10):
Expand Down
5 changes: 2 additions & 3 deletions graphmuse/samplers/sampling_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def sample_from_graphlist(self, graphlist):

if self.sample_rightmost:
# Sample rightmost node-wise by num layers (because of reverse edges missing)
right_layers, edges_between_right_layers = csamplers.sample_neighbors_in_score_graph(random_graph.note_array, self.num_layers-2, self.samples_per_node, right_extension)
right_layers, edges_between_right_layers, total_right_samples = csamplers.sample_neighbors_in_score_graph(random_graph.note_array, self.num_layers-2, self.samples_per_node, right_extension)
else:
right_layers, edges_between_right_layers = [], []
edges_between_layers = torch.cat((left_edges, right_edges, edges_between_right_layers, edges_between_left_layers), dim=1)
Expand All @@ -158,8 +158,7 @@ def sample_from_graphlist(self, graphlist):

# Translate edges to subgraph indices (do this on GPU when available).
subgraph_edge_index = torch.cat((edges_within_region, edges_between_layers), dim=1).to(self.device)
# TODO: add total right samples
sampled_nodes = torch.cat((torch.arange(region[0], region[1]), sampled_nodes_first_layer, total_left_samples)).to(self.device)
sampled_nodes = torch.cat((torch.arange(region[0], region[1]), sampled_nodes_first_layer, total_left_samples, total_right_samples)).to(self.device)
new_mapping = torch.arange(sampled_nodes.shape[0], device=self.device)
nodes_remap = torch.empty_like(random_graph.x.shape[0]).to(self.device)
nodes_remap[sampled_nodes] = new_mapping
Expand Down
12 changes: 11 additions & 1 deletion include/musical_sampling.c
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject*
HashSet node_hash_set;
HashSet_new(&node_hash_set, (Index)PyArray_SIZE(prev_layer));

HashSet total_samples;
HashSet_new(&total_samples, (Index)PyArray_SIZE(prev_layer));
HashSet_init(&total_samples);

HashSet node_tracker;
HashSet_new(&node_tracker, samples_per_node);

Expand Down Expand Up @@ -406,6 +410,7 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject*
if(neighbor_count <= samples_per_node){
for(Index j=lower_bound; j<upper_bound; j++){
HashSet_add_node(&node_hash_set, j);
HashSet_add_node(&total_samples, j);

edge_list[2*edge_list_cursor]=j;
edge_list[2*edge_list_cursor+1]=i;
Expand Down Expand Up @@ -435,6 +440,7 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject*
Node j = perm[rand_ix];

HashSet_add_node(&node_hash_set, j);
HashSet_add_node(&total_samples, j);

edge_list[2*edge_list_cursor]=j;
edge_list[2*edge_list_cursor+1]=i;
Expand All @@ -458,6 +464,7 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject*
}

HashSet_add_node(&node_hash_set, j);
HashSet_add_node(&total_samples, j);


edge_list[2*edge_list_cursor]=j;
Expand All @@ -480,11 +487,14 @@ static PyObject* sample_neighbors_in_score_graph(PyObject* csamplers, PyObject*
PyList_SET_ITEM(edges_between_layers, layer-1, (PyObject*)edges);
}

PyArrayObject* np_total_samples = HashSet_to_numpy(&total_samples);

HashSet_free(&total_samples);
HashSet_free(&node_tracker);
HashSet_free(&node_hash_set);
free(edge_list);

return PyTuple_Pack(2, samples_per_layer, edges_between_layers);
return PyTuple_Pack(3, samples_per_layer, edges_between_layers, np_total_samples);
}

static PyObject* sample_preneighbors_within_region(PyObject* csamplers, PyObject* args){
Expand Down

0 comments on commit fb532a7

Please sign in to comment.