diff --git a/.gitignore b/.gitignore
index 3e6aee68..5b5bbc39 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,6 @@
 cython_debug/
 *.ckpt
 *.wav
-wandb/*
\ No newline at end of file
+wandb/*
+models/*
+outputs/*
diff --git a/config/txt2audio.json b/config/txt2audio.json
new file mode 100644
index 00000000..6c447ed9
--- /dev/null
+++ b/config/txt2audio.json
@@ -0,0 +1,3 @@
+{
+    "model_selected": ""
+}
diff --git a/run_gradio.py b/run_gradio.py
index ae3ba95c..18e61dd0 100644
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -15,7 +15,7 @@ def main(args):
         model_half=args.model_half
     )
     interface.queue()
-    interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None)
+    interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None, inbrowser=args.inbrowser if args.inbrowser is not None else False)
 
 if __name__ == "__main__":
     import argparse
@@ -27,5 +27,6 @@ def main(args):
     parser.add_argument('--username', type=str, help='Gradio username', required=False)
     parser.add_argument('--password', type=str, help='Gradio password', required=False)
     parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
+    parser.add_argument('--inbrowser', action='store_true', help='Open browser on launch', required=False)
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7e7470d3..0232b886 100644
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,7 @@
         'vector-quantize-pytorch==1.9.14',
         'wandb==0.15.4',
         'webdataset==0.2.48',
-        'x-transformers<1.27.0'
+        'x-transformers<1.27.0',
+        'pytaglib==3.0.0'
     ],
)
\ No newline at end of file
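
The model picker wired up in the files below does not add new CLI arguments for choosing a checkpoint. It scans models/ for weight files (.ckpt, .safetensors or .pth) that sit next to a JSON model config with the same basename, and config/txt2audio.json only remembers which of those models was last selected. A hypothetical layout (names are illustrative, not part of this diff):

    models/my-finetune.ckpt     weights file
    models/my-finetune.json     matching model config, same basename
    config/txt2audio.json       {"model_selected": "my-finetune"}

The new --inbrowser option is a plain store_true flag and is passed alongside the existing arguments, for example: python run_gradio.py --model-half --inbrowser
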
diff --git a/stable_audio_tools/data/txt2audio_utils.py b/stable_audio_tools/data/txt2audio_utils.py
new file mode 100644
index 00000000..e6b60279
--- /dev/null
+++ b/stable_audio_tools/data/txt2audio_utils.py
@@ -0,0 +1,136 @@
+import taglib
+import os
+from datetime import datetime
+import platform
+import subprocess
+import json
+
+def set_selected_model(model_name):
+    if model_name in [data["name"] for data in get_models_data()]:
+        config = get_config()
+        config["model_selected"] = model_name
+        with open("config/txt2audio.json", "w") as file:
+            json.dump(config, file, indent=4)
+            file.write('\n')
+
+def get_config():
+    with open("config/txt2audio.json") as file:
+        return json.load(file)
+
+def get_models_name():
+    return [model["name"] for model in get_models_data()]
+
+def get_models_data():
+    models = []
+    file_types = ['.ckpt', '.safetensors', '.pth']
+    for file in os.listdir("models/"):
+        _file = os.path.splitext(file)
+        config_path = f"models/{_file[0]}.json"
+        if _file[1] in file_types and os.path.isfile(config_path):
+            models.append({"name": _file[0], "path": f"models/{file}", "config_path": config_path})
+    return models
+
+def open_outputs_path():
+    outputs_dir = "outputs/"
+    outputs = outputs_dir + datetime.now().strftime('%Y-%m-%d')
+
+    if not os.path.isdir(outputs):
+        if not os.path.isdir(outputs_dir):
+            return
+        else:
+            outputs = outputs_dir
+
+    outputs = os.path.abspath(outputs)
+    if platform.system() == "Windows":
+        os.startfile(outputs)
+    elif platform.system() == "Darwin":
+        subprocess.Popen(["open", outputs])
+    elif "microsoft-standard-WSL2" in platform.uname().release:
+        subprocess.Popen(["wsl-open", outputs])
+    else:
+        subprocess.Popen(["xdg-open", outputs])
+
+def create_output_path(suffix):
+    outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}"
+    count = 0
+
+    if os.path.isdir(outputs):
+        counts = [os.path.splitext(file)[0].split('-')[0] for file in os.listdir(outputs) if file.endswith(".wav")]
+        count = max([int(i) for i in counts if i.isnumeric()], default=-1) + 1
+    else:
+        os.makedirs(outputs)
+
+    return f"{outputs}/{'{:05d}'.format(count)}-{suffix}.wav"
+
+def get_generation_data(file):
+    with taglib.File(file) as sound:
+        if len(sound.tags) != 1:
+            return None
+
+        data = sound.tags["TITLE"]
+
+        if len(data) != 12:
+            return None
+        if data[0] == "None":
+            data[0] = ""
+        if data[1] == "None":
+            data[1] = ""
+        if data[5] == "None":
+            data[5] = 0
+
+        # seconds_start, seconds_total, steps, preview_every and seed are ints
+        for i in (2, 3, 4, 5, 7):
+            data[i] = int(data[i])
+
+        # cfg_scale, sigma_min, sigma_max and cfg_rescale are floats
+        for i in (6, 9, 10, 11):
+            data[i] = float(data[i])
+
+        return data
+
+def save_generation_data(sound_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale):
+    if prompt == "":
+        prompt = "None"
+    if negative_prompt == "":
+        negative_prompt = "None"
+
+    with taglib.File(sound_path, save_on_exit=True) as sound:
+        sound.tags["TITLE"] = [
+            prompt,
+            negative_prompt,
+            str(seconds_start),
+            str(seconds_total),
+            str(steps),
+            str(preview_every),
+            str(cfg_scale),
+            str(seed),
+            str(sampler_type),
+            str(sigma_min),
+            str(sigma_max),
+            str(cfg_rescale)]
+
+def txt2audio_css():
+    return """
+    div.svelte-sa48pu>*, div.svelte-sa48pu>.form>* {
+        flex: 1 1 0%;
+        flex-wrap: wrap;
+        min-width: min(40px, 100%);
+    }
+
+    #refresh_btn {
+        padding: 0px;
+    }
+
+    #selected_model_items div.svelte-1sk0pyu div.wrap.svelte-1sk0pyu div.wrap-inner.svelte-1sk0pyu div.secondary-wrap.svelte-1sk0pyu input.border-none.svelte-1sk0pyu {
+        margin: 0px;
+    }
+
+    #prompt_options {
+        flex-wrap: nowrap;
+        height: 40px;
+    }
+
+    #selected_model_container {
+        gap: 3px;
+    }
+    """
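
Since txt2audio_utils.py stores every generation setting inside the WAV file itself, a minimal sketch of that round-trip may help. It assumes only pytaglib (pinned in setup.py above) and an already rendered WAV file, and uses illustrative values in the order that save_generation_data() writes and get_generation_data() expects:

import taglib

# The 12 values are kept as strings in the file's TITLE tag, in this fixed order
# (an empty prompt is stored as the literal string "None"). All values below are illustrative.
settings = [
    "warm analog synth arpeggio",   # prompt
    "None",                         # negative_prompt ("" is saved as "None")
    "0",                            # seconds_start
    "30",                           # seconds_total
    "100",                          # steps
    "0",                            # preview_every
    "7.0",                          # cfg_scale
    "123456789",                    # seed
    "dpmpp-3m-sde",                 # sampler_type
    "0.03",                         # sigma_min
    "500",                          # sigma_max
    "0.0",                          # cfg_rescale
]

wav_path = "outputs/2024-01-01/00000-123456789.wav"  # shape produced by create_output_path(seed)

# Write: pytaglib exposes tags as a dict of str -> list of str.
with taglib.File(wav_path, save_on_exit=True) as sound:
    sound.tags["TITLE"] = settings

# Read back: get_generation_data() relies on exactly 12 values coming back from TITLE.
with taglib.File(wav_path) as sound:
    print(sound.tags["TITLE"])
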
diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py
index b46c8d43..9bc3aa66 100644
--- a/stable_audio_tools/interface/gradio.py
+++ b/stable_audio_tools/interface/gradio.py
@@ -21,6 +21,7 @@
 model = None
 sample_rate = 32000
 sample_size = 1920000
+model_is_half = None
 
 def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
     global model, sample_rate, sample_size
@@ -55,6 +56,32 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr
 
     return model, model_config
 
+def unload_model():
+    global model
+    del model
+    model = None
+    torch.cuda.empty_cache()
+    gc.collect()
+
+def txt2audio_change_model(model_name):
+    from stable_audio_tools.data.txt2audio_utils import get_models_data, set_selected_model
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    for model_data in get_models_data():
+        if model_data["name"] == model_name:
+            unload_model()
+            set_selected_model(model_name)
+            model_config = get_model_config_from_path(model_data["config_path"])
+            load_model(model_config, model_data["path"], model_half=model_is_half, device=device)
+    return model_name
+
+def get_model_config_from_path(model_config_path):
+    if model_config_path is not None:
+        # Load config from json file
+        with open(model_config_path) as f:
+            return json.load(f)
+    else:
+        return None
+
 def generate_cond(
         prompt,
         negative_prompt=None,
@@ -163,6 +190,7 @@ def progress_callback(callback_info):
     else:
         mask_args = None
 
+    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
     # Do the audio generation
     audio = generate_diffusion_cond(
         model,
@@ -188,12 +216,16 @@ def progress_callback(callback_info):
 
     # Convert to WAV file
     audio = rearrange(audio, "b d n -> d (b n)")
     audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-    torchaudio.save("output.wav", audio, sample_rate)
+
+    from stable_audio_tools.data.txt2audio_utils import create_output_path, save_generation_data
+    output_path = create_output_path(seed)
+    torchaudio.save(output_path, audio, sample_rate)
+    save_generation_data(output_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale)
 
     # Let's look at a nice spectrogram too
     audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
 
-    return ("output.wav", [audio_spectrogram, *preview_images])
+    return (output_path, [audio_spectrogram, *preview_images])
 
 def generate_uncond(
         steps=250,
@@ -380,6 +412,10 @@ def create_sampling_ui(model_config, inpainting=False):
         with gr.Column(scale=6):
             prompt = gr.Textbox(show_label=False, placeholder="Prompt")
             negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
+            with gr.Row(elem_id="prompt_options"):
+                clear_prompt = gr.Button('\U0001f5d1\ufe0f')
+                paste_generation_data = gr.Button('\u2199\ufe0f')
+            insert_generation_data = gr.File(label="Insert generation data from output.wav", file_types=[".wav"], scale=0)
         generate_button = gr.Button("Generate", variant='primary', scale=1)
 
     model_conditioning_config = model_config["model"].get("conditioning", None)
@@ -492,11 +528,31 @@ def create_sampling_ui(model_config, inpainting=False):
 
         with gr.Column():
             audio_output = gr.Audio(label="Output audio", interactive=False)
-            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
-            send_to_init_button = gr.Button("Send to init audio", scale=1)
+            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
+            with gr.Row():
+                open_outputs_folder = gr.Button("\U0001f4c1", scale=1)
+                send_to_init_button = gr.Button("Send to init audio", scale=1)
+            from stable_audio_tools.data.txt2audio_utils import open_outputs_path
+            open_outputs_folder.click(fn=open_outputs_path)
             send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
 
-    generate_button.click(fn=generate_cond,
+    from stable_audio_tools.data.txt2audio_utils import get_generation_data
+    paste_generation_data.click(fn=get_generation_data, inputs=[insert_generation_data], outputs=[prompt,
+        negative_prompt,
+        seconds_start_slider,
+        seconds_total_slider,
+        steps_slider,
+        preview_every_slider,
+        cfg_scale_slider,
+        seed_textbox,
+        sampler_type_dropdown,
+        sigma_min_slider,
+        sigma_max_slider,
+        cfg_rescale_slider])
+
+    clear_prompt.click(fn=lambda: ("", ""), outputs=[prompt, negative_prompt])
+
+    generate_button.click(fn=generate_cond,
         inputs=inputs,
         outputs=[
            audio_output,
@@ -504,9 +560,17 @@ def create_sampling_ui(model_config, inpainting=False):
            audio_spectrogram_output
        ],
        api_name="generate")
 
 
 def create_txt2audio_ui(model_config):
-    with gr.Blocks() as ui:
+    from stable_audio_tools.data.txt2audio_utils import txt2audio_css, get_models_name, get_config
+    with gr.Blocks(css=txt2audio_css()) as ui:
+        with gr.Column(elem_id="selected_model_container"):
+            gr.HTML('', visible=True)
+            with gr.Row():
+                selected_model_dropdown = gr.Dropdown(get_models_name(), container=False, value=get_config()["model_selected"], interactive=True, scale=0, min_width=265, elem_id="selected_model_items")
+                selected_model_dropdown.change(fn=txt2audio_change_model, inputs=selected_model_dropdown, outputs=selected_model_dropdown)
+                refresh_models_button = gr.Button("\U0001f504", scale=0, elem_id="refresh_btn")
+                refresh_models_button.click(fn=lambda: gr.Dropdown(choices=get_models_name()), outputs=selected_model_dropdown)
+            gr.HTML('