forked from raraz15/MaxGrooVAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
IO.py
179 lines (159 loc) · 7.5 KB
/
IO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Most of the methods here are integrated for our purposes from https://goo.gl/magenta/groovae-colab
import copy
from collections import defaultdict
import numpy as np
import note_seq
from note_seq.protobuf import music_pb2
from magenta.models.music_vae import configs
# Data converter of the pretrained 2-bar "tap" GrooVAE config (NoteSequence <-> model tensors).
dc_tap = configs.CONFIG_MAP['groovae_2bar_tap_fixed_velocity'].data_converter
# Model configuration and checkpoint archive path (loaded elsewhere to build the network).
model_config=configs.CONFIG_MAP['groovae_2bar_tap_fixed_velocity']
model_weights_path="groovae_2bar_tap_fixed_velocity.tar"
# Grid geometry: 2 bars of 4 beats at 4 steps per quarter note -> 32 steps total.
N_BARS=2
BEATS_PER_BAR=4
STEPS_PER_QUARTER_NOTE=4
N_STEPS=N_BARS*BEATS_PER_BAR*STEPS_PER_QUARTER_NOTE
VELOCITY=85 # Fixed Value
def start_notes_at_0(s):
    """Shift every note that starts before time 0 so it starts at 0.

    The note's duration is preserved (the end time is pushed forward by the
    same amount). Mutates ``s`` in place and returns it.
    """
    for note in s.notes:
        if note.start_time >= 0:
            continue
        # Shift the whole note forward by |start_time| so it begins at 0.
        note.end_time -= note.start_time
        note.start_time = 0
    return s
# Some midi files come by default from different instrument channels
# Quick and dirty way to set midi files to be recognized as drums
def set_to_drums(ns):
    """Force every note in ``ns`` onto instrument 9 and flag it as a drum.

    Mutates ``ns`` in place; returns None.
    """
    for note in ns.notes:
        note.is_drum = True
        note.instrument = 9
def change_tempo(note_sequence, new_tempo):
    """Return a deep copy of ``note_sequence`` retimed to ``new_tempo``.

    All note start/end times are scaled by old_qpm / new_qpm so the music
    occupies the same number of beats at the new tempo. The input sequence
    is left untouched.
    """
    retimed = copy.deepcopy(note_sequence)
    scale = note_sequence.tempos[0].qpm / new_tempo
    for note in retimed.notes:
        note.start_time *= scale
        note.end_time *= scale
    retimed.tempos[0].qpm = new_tempo
    return retimed
# Calculate quantization steps but do not remove microtiming
def quantize(s, steps_per_quarter=4):
    """Return a quantized deep copy of ``s``.

    Only the quantized step fields are computed by note_seq; the original
    (microtimed) start/end times on the copy are not flattened here.
    """
    sequence_copy = copy.deepcopy(s)
    return note_seq.sequences_lib.quantize_note_sequence(sequence_copy, steps_per_quarter)
# Destructively quantize a midi sequence
def flatten_quantization(s, steps_per_quarter=4):
    """Snap note times of a quantized sequence exactly onto the step grid.

    Args:
        s: a quantized NoteSequence-like object whose notes carry
           ``quantized_start_step`` / ``quantized_end_step`` and whose first
           tempo defines the qpm.
        steps_per_quarter: grid resolution in steps per quarter note.
            Previously hard-coded to 4 (the commented-out
            ``s.quantization_info.steps_per_quarter`` suggests it was meant
            to be configurable); the default preserves the old behavior.

    Returns:
        A deep copy of ``s`` with every note's start/end time replaced by
        its quantized step position; the input is not modified.
    """
    beat_length = 60. / s.tempos[0].qpm
    step_length = beat_length / steps_per_quarter
    new_s = copy.deepcopy(s)
    for note in new_s.notes:
        note.start_time = step_length * note.quantized_start_step
        note.end_time = step_length * note.quantized_end_step
    return new_s
def add_silent_note(note_sequence, num_bars):
    """Append a zero-velocity drum note just before the end of ``num_bars``.

    The bar length assumes 4 beats per bar (the ``* 4``). The silent note
    pads the sequence out to its nominal length. Mutates ``note_sequence``.
    """
    bar_seconds = 60 / note_sequence.tempos[0].qpm * 4
    end = bar_seconds * num_bars
    note_sequence.notes.add(instrument=9, pitch=42, velocity=0, is_drum=True,
                            start_time=end - 0.02, end_time=end - 0.01)
def is_4_4(s):
    """Return True iff the first time signature of ``s`` is 4/4."""
    signature = s.time_signatures[0]
    return signature.numerator == 4 and signature.denominator == 4
def quantize_to_beat_divisions(beat, division=32):
    """Floor-quantize a beat position onto a 1/``division``-beat grid.

    ``division=1`` is a no-op passthrough (no quantization).
    """
    if division == 1:
        return beat
    step = 1 / division
    return (beat // step) * step
# quick method for turning a drumbeat into a tapped rhythm
def get_tapped_2bar(s, velocity=VELOCITY, ride=False):
    """Collapse a drum beat into a 2-bar tap sequence via the tap converter.

    Round-trips ``s`` through dc_tap (tensor encode/decode), restores the
    original tempo, and optionally overrides velocities (when ``velocity``
    is non-zero) and forces every pitch to 42 (when ``ride`` is True).
    """
    tensors = dc_tap.to_tensors(s)
    tapped = dc_tap.from_tensors(tensors.inputs)[0]
    tapped = change_tempo(tapped, s.tempos[0].qpm)
    if velocity != 0:
        for note in tapped.notes:
            note.velocity = velocity
    if ride:
        for note in tapped.notes:
            note.pitch = 42
    return tapped
def drumify(s, model, temperature=0.5, N=1):
    """Encode the groove ``N`` times and decode ``N`` drum sequences.

    ``model.encode`` returns (encoding, mu, sigma); only the encoding is
    used for decoding — presumably mu/sigma are the latent distribution
    parameters (TODO confirm against the model API).
    """
    z, _mu, _sigma = model.encode([s] * N)
    return model.decode(z, length=N_STEPS, temperature=temperature)
def max_str_to_midi_array(max_str, BPM):
    """Parse a Max list string into an (N_STEPS, 3) array of [start_s, end_s, velocity].

    ``max_str`` holds N_STEPS space-separated triplets ``start end velocity``
    with start/end expressed in bars; 4/4 timing is assumed, hence the ``* 4``
    converting bars to beats before scaling to seconds.

    Args:
        max_str: space-separated string of 3*N_STEPS numeric values.
        BPM: tempo used to convert beats to seconds.

    Returns:
        np.ndarray of shape (N_STEPS, 3) with rows [start_time, end_time, velocity].

    Raises:
        ValueError: if the string does not contain exactly 3*N_STEPS values.
    """
    tokens = max_str.split(' ')
    if len(tokens) != 3 * N_STEPS:
        # Explicit exception instead of `assert`, which is stripped under `python -O`.
        raise ValueError('List length wrong!')
    beat_dur = 60 / BPM  # seconds per beat
    midi_array = []
    # Walk the flat token list three at a time: (start_bar, end_bar, velocity).
    it = iter(tokens)
    for start_bar, end_bar, vel in zip(it, it, it):
        start_time = 4 * float(start_bar) * beat_dur  # bars -> beats -> seconds
        end_time = 4 * float(end_bar) * beat_dur
        midi_array.append([start_time, end_time, float(vel)])
    return np.array(midi_array)
def make_tap_sequence(midi_array, BPM, velocity=VELOCITY, tpq=220):
    """Build a constant-pitch drum NoteSequence from ``midi_array``.

    Each row of ``midi_array`` is [onset_time, offset_time, onset_velocity];
    rows with zero velocity are treated as silence and skipped. Every emitted
    note is a drum hit on instrument 9 at pitch 42 with the fixed ``velocity``.
    Total time is set to the full N_BARS length regardless of note content.
    """
    sequence = music_pb2.NoteSequence()
    sequence.tempos.add(qpm=BPM)
    sequence.ticks_per_quarter = tpq
    sequence.time_signatures.add(numerator=BEATS_PER_BAR, denominator=4)
    sequence.key_signatures.add()
    for onset, offset, onset_vel in midi_array:
        if not onset_vel:
            continue  # zero-velocity rows carry no note
        sequence.notes.add(instrument=9,  # Drum MIDI Program number
                           pitch=42,      # Constant
                           is_drum=True,
                           velocity=velocity,
                           start_time=onset,
                           end_time=offset)
    sequence.total_time = N_BARS * BEATS_PER_BAR * (60 / BPM)
    return sequence
def NN_output_to_Max(h, BPM, pre_quantization=False, beat_quantization_division=1):
    """Convert an NN NoteSequence into Max-readable drum messages.

    Returns a dict {pitch: message} where each message is a space-separated
    string of [start, duration, velocity] triplets, start/duration given in
    milli-beats (int(1000 * beats)). Keys are sorted for pretty printing.
    """
    sequence = copy.deepcopy(h)
    beat_dur = 60 / BPM
    if pre_quantization:
        sequence = quantize(sequence)
    midi_lists = defaultdict(list)
    for note in sequence.notes:
        # Sanity-check the raw NN timings before converting.
        assert note.start_time>=0, f'Negative Start time received from NN! {note.start_time}'
        assert note.end_time>=0, f'Negative End time received from NN! {note.end_time}'
        assert note.end_time>note.start_time, 'End time before start time from NN!'
        start_beat = quantize_to_beat_divisions(note.start_time / beat_dur, beat_quantization_division)
        end_beat = quantize_to_beat_divisions(note.end_time / beat_dur, beat_quantization_division)
        assert start_beat>=0, f'Start beat quantized wrongly! {start_beat}'
        assert end_beat>=0, f'End beat quantized wrongly! {end_beat}'
        assert end_beat>start_beat, 'Duration quantized wrongly.'
        # Max expects integer milli-beat values.
        midi_lists[note.pitch].extend([int(1000 * start_beat),
                                       int(1000 * (end_beat - start_beat)),
                                       note.velocity])
    # Cast each drum's values to one space-separated string; sort keys for pretty print.
    return sort_dict_by_key(
        {pitch: ' '.join(str(v) for v in values) for pitch, values in midi_lists.items()})
# TODO: take beat_quantization_division
def max_to_NN_to_max(max_lst, BPM, model, temperature=1.0, beat_quantization_division=64, N=1):
    """Run a Max tap list through the NN and return N Max-readable results.

    Pipeline: parse the Max string -> build and quantize a tap NoteSequence
    -> convert to the model's 2-bar tapped input -> decode N drum sequences
    -> format each as a dict of Max drum messages.
    """
    midi_array = max_str_to_midi_array(max_lst, BPM)
    # Build the pre-NN tap sequence.
    tap = make_tap_sequence(midi_array, BPM)
    tap = quantize(tap)
    set_to_drums(tap)
    # Normalize into the NN input format.
    tap = start_notes_at_0(tap)
    tap = change_tempo(get_tapped_2bar(tap, velocity=VELOCITY, ride=True), BPM)
    assert BPM==tap.tempos[0].qpm, 'Tempo conversion failed at tapped bar creation'
    # Decode N drum compositions as NoteSequences.
    compositions = drumify(tap, model, temperature=temperature, N=N)
    messages = []
    for composition in compositions:
        composition = start_notes_at_0(composition)  # remove negative timings
        composition = change_tempo(composition, BPM)  # Inherited from Magenta don't know why
        assert BPM==composition.tempos[0].qpm, 'Tempo conversion failed at NN creation'
        # Convert to Max messages.
        messages.append(NN_output_to_Max(composition, BPM,
                                         beat_quantization_division=beat_quantization_division))
    return messages
def sort_dict_by_key(dct):
    """Return a new dict with the same items, insertion-ordered by ascending key."""
    return dict(sorted(dct.items(), key=lambda item: item[0]))