# sd-riffusion / app.py
# Author: juancopi81
# Commit 35afeb7: "Add file download option"
import random
import tempfile

from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
import torch

from spectro import wav_bytes_from_spectrogram_image
# Prefer the GPU in half precision; fall back to full precision on the CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Stable Diffusion v1.5: renders the background image (used by get_bg_image).
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    revision="fp16",
).to(device)

# Riffusion: renders spectrogram images that get_music converts to audio.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(
    model_id2,
    torch_dtype=dtype,
).to(device)
# Candidate bar-colour pairs for gr.make_waveform; infer() picks one at random
# per request. Each pair steps through red -> green -> blue -> red.
_RED, _GREEN, _BLUE = "#ff0000", "#00ff00", "#0000ff"
COLORS = [
    [_RED, _GREEN],
    [_GREEN, _BLUE],
    [_BLUE, _RED],
]
# HTML header rendered at the top of the app via gr.HTML(title).
# The app only runs inference, so the duplicate-Space note no longer says
# "for training". The <h1> keeps a single font-weight declaration; "bold" was
# already the one that took effect.
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
<div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
<h1 style="margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
</div>
<p style="text-align: center;font-size: 94%">
Duplicate this Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU:
<span style="display: flex;align-items: center;justify-content: center;height: 30px;">
<a href="https://huggingface.co/spaces/juancopi81/sd-riffusion?duplicate=true">
<img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
</span>
</p>
<p style="text-align: center;font-size: 94%">
You can buy me a coffee to support this space:
<span style="display: flex;align-items: center;justify-content: center;height: 30px;">
<a href="https://www.buymeacoffee.com/juancopi81j">
<img src="https://badgen.net/badge/icon/Buy%20Me%20A%20Coffee?icon=buymeacoffee&label" alt="Buy me a coffee"></a>. Depending on the support, I'll keep this space running and add more features!
</span>
</p>
</div>
"""
def get_bg_image(prompt):
    """Render a background image for ``prompt`` with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt passed to ``pipe``.

    Returns:
        Filesystem path of a newly created ``.png`` file holding the first
        generated image. Each call gets its own file, so concurrent requests
        cannot overwrite each other's output (the previous fixed name
        ``img.png`` could).
    """
    image_output = pipe(prompt).images[0]
    print("Image generated!")
    # delete=False: the path must outlive this function because the caller
    # passes it on. NOTE(review): files are never removed here; consider
    # periodic cleanup of the temp dir on long-running deployments.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_output.save(tmp, format="PNG")
    return tmp.name
def get_music(prompt):
duration = 10
if duration == 5:
width_duration=512
else :
width_duration = 512 + ((int(duration)-5) * 128)
spec = pipe2(prompt, height=512, width=width_duration).images[0]
print(spec)
wav = wav_bytes_from_spectrogram_image(spec)
with open("output.wav", "wb") as f:
f.write(wav[0].getbuffer())
return "output.wav"
def infer(prompt, style):
style_prompt = prompt + style
image = get_bg_image(style_prompt)
audio = get_music(prompt)
video = gr.make_waveform(audio,
bg_image=image,
bars_color=random.choice(COLORS))
return video, video
# Custom stylesheet passed to gr.Blocks(css=css). The selectors target the
# elem_id values assigned to components in the Blocks layout
# (#col-container, #prompt-in, #prompt-style, #btn-container, #submit-btn).
css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
#prompt-in {
border: 2px solid #666;
border-radius: 2px;
padding: 8px;
}
#prompt-style {
border: 2px solid #666;
border-radius: 2px;
padding: 8px;
}
#btn-container {
display: flex;
align-items: center;
justify-content: center;
width: calc(15% - 16px);
height: calc(15% - 16px);
}
/* Style the submit button */
#submit-btn {
background-color: #382a1d;
color: #fff;
border: 1px solid #000;
border-radius: 4px;
padding: 8px;
font-size: 16px;
cursor: pointer;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        # Inputs: the prompt drives both models; the style only the image.
        prompt_input = gr.Textbox(
            label="Enter your music prompt.",
            placeholder="The Beatles playing for the queen",
            elem_id="prompt-in",
        )
        style_input = gr.Textbox(
            label="(Optional) Add styles to your background image.",
            placeholder="In the style of Vincent van Gogh",
            elem_id="prompt-style",
            value="",
        )
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        # Outputs: infer returns the same video path for the player and the
        # downloadable file.
        video_out = gr.Video()
        file_out = gr.File()
        send_btn.click(
            infer,
            inputs=[prompt_input, style_input],
            outputs=[video_out, file_out],
        )
        gr.Markdown("""
[![Twitter Follow](https://img.shields.io/twitter/follow/juancopi81?style=social)](https://twitter.com/juancopi81)
![visitors](https://visitor-badge.glitch.me/badge?page_id=Juancopi81.sd-riffusion)
""")

demo.launch(debug=True)