Spaces:
Running
Running
import io
import math
import os
import random
import tempfile

import gradio as gr
import numpy as np
import requests
from PIL import Image
# Upper bound for seeds offered by the UI and drawn by predict(): the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max
# Hugging Face token taken from the HF_TOKEN environment variable (None when unset).
API_TOKEN = os.environ.get("HF_TOKEN")
# Default request headers built from the token above.
headers = dict(Authorization="Bearer {}".format(API_TOKEN))
# Timeout passed to requests.post, in seconds.
timeout = 100
def split_image(input_image, num_splits=4):
    """Split a square-grid image into its individual frames.

    The image is treated as an ``n x n`` grid where ``n * n == num_splits``
    (the default 4 gives a 2x2 grid). Cells are returned in row-major order,
    i.e. top-left, top-right, bottom-left, bottom-right for a 2x2 grid.
    Cell edges are derived from ``input_image.size`` instead of assuming a
    1024x1024 input, so other resolutions are split correctly as well.

    Args:
        input_image: A PIL image (anything exposing ``.size`` and ``.crop(box)``).
        num_splits: Total number of frames; must be a positive perfect square.

    Returns:
        List of cropped images, one per grid cell.

    Raises:
        ValueError: If ``num_splits`` is not a positive perfect square.
    """
    if num_splits < 1 or math.isqrt(num_splits) ** 2 != num_splits:
        raise ValueError(f"num_splits must be a positive perfect square, got {num_splits}")
    side = math.isqrt(num_splits)
    width, height = input_image.size
    # Integer edges computed from the full size so uneven dimensions still
    # tile the entire image with no gaps or overlaps.
    xs = [col * width // side for col in range(side + 1)]
    ys = [row * height // side for row in range(side + 1)]
    return [
        input_image.crop((xs[col], ys[row], xs[col + 1], ys[row + 1]))
        for row in range(side)
        for col in range(side)
    ]
def export_to_gif(images, output_path, fps=4):
    """Save a sequence of frames to ``output_path`` as an endlessly looping GIF.

    The first frame carries the ``save`` call and the remaining frames are
    appended in order. Each frame lasts ``int(1000 / fps)`` milliseconds.
    """
    frame_ms = int(1000 / fps)
    first_frame, later_frames = images[0], images[1:]
    # loop=0 tells the GIF writer to repeat forever.
    first_frame.save(output_path, save_all=True, append_images=later_frames, duration=frame_ms, loop=0)
_DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-dev"


def _resolve_model_id(lora_id):
    """Return the stripped model/LoRA repo id, or the FLUX.1-dev base model when blank or None."""
    # Check for None *before* calling .strip(); the reverse order crashes on None.
    if lora_id is None or not lora_id.strip():
        return _DEFAULT_MODEL_ID
    return lora_id.strip()


def _query_model(model_id, payload):
    """POST ``payload`` to the HF Inference API for ``model_id`` and return the raw response bytes.

    Raises:
        gr.Error: On a network/timeout failure or any non-200 response.
    """
    api_url = "https://api-inference.huggingface.co/models/" + model_id
    # Read the token per call so a changed HF_TOKEN secret is picked up.
    request_headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    try:
        response = requests.post(api_url, headers=request_headers, json=payload, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"Error: request to {api_url} failed: {e}")
        raise gr.Error(f"Request to the inference API failed: {e}") from e
    if response.status_code != 200:
        print(f"Error: Failed to get image. Response status: {response.status_code}")
        print(f"Response content: {response.text}")
        if response.status_code == 503:
            raise gr.Error(f"{response.status_code} : The model is being loaded")
        raise gr.Error(f"{response.status_code}")
    return response.content


def predict(prompt, seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, lora_id="black-forest-labs/FLUX.1-dev", progress=gr.Progress(track_tqdm=True)):
    """Generate a 2x2 frame grid with FLUX and turn it into a looping GIF.

    The prompt is wrapped in a template asking for four consecutive stills.
    The request goes to the Hugging Face Inference API for ``lora_id``
    (FLUX.1-dev when blank). The returned image is split into four frames
    and written to a fresh temporary GIF file.

    Args:
        prompt: User description of the animation.
        seed: Seed forwarded to the API; replaced when ``randomize_seed`` is true.
        randomize_seed: Draw a new seed in ``[0, MAX_SEED]`` for this call.
        guidance_scale: Forwarded to the API as ``guidance_scale``.
        num_inference_steps: Forwarded to the API as ``num_inference_steps``.
        lora_id: Model or LoRA repo id; blank or None means FLUX.1-dev.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        Tuple ``(gif_path, grid_image, seed)`` for the GIF, stills and seed outputs.

    Raises:
        gr.Error: If the request fails, the API returns a non-200 status,
            or the response body cannot be decoded as an image.
    """
    prompt_template = f"""a 2x2 total 4 grid of frames, showing consecutive stills from a looped gif of {prompt}"""
    model_id = _resolve_model_id(lora_id)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # UI sliders may deliver floats.
    key = random.randint(0, 999)  # Only used to tag the log line below.

    # Generation options belong under "parameters"; unknown top-level keys
    # are not part of the documented request schema.
    payload = {
        "inputs": prompt_template,
        "parameters": {
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": float(guidance_scale),
            "seed": seed,
        },
    }
    image_bytes = _query_model(model_id, payload)

    try:
        image = Image.open(io.BytesIO(image_bytes))
        image.load()  # Image.open is lazy; force decoding so bad data fails here.
    except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError.
        print(f"Error when trying to open the image: {e}")
        raise gr.Error(f"Could not decode the generated image: {e}") from e
    print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')

    split_images = split_image(image, num_splits=4)
    # A unique file per request; a fixed "output.gif" would be overwritten by
    # concurrent users of the app.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name
    export_to_gif(split_images, gif_path, fps=4)
    return gif_path, image, seed
# Page CSS: centre the main column at 520px wide and cap the height of the
# element with id "stills".
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
#stills{max-height:160px}
"""

# Example prompts shown under the UI; each one describes a motion.
examples = [
    "a cat waving its paws in the air",
    "a panda moving their hips from side to side",
    "a flower going through the process of blooming",
]
# Build the Gradio UI: a prompt row, GIF and stills outputs, an
# "Advanced Settings" accordion, example prompts, and the predict wiring.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX Gif Generator")
        gr.Markdown("Create GIFs with Flux-dev. Based on @dn6's [space](https://huggingface.co/spaces/dn6/FLUX-GIFs) and @fofr's [tweet](https://x.com/fofrAI/status/1828910395962343561).")
        gr.Markdown("Add LoRA (if needed) in Advanced Settings. For better results, include a description of the motion in your prompt.")
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=4, show_copy_button=True, placeholder="Enter your prompt", container=False)
            submit = gr.Button("Submit", scale=0)
        output = gr.Image(label="GIF", show_label=False)
        # elem_id "stills" is height-limited by the page CSS.
        output_stills = gr.Image(label="stills", show_label=False, elem_id="stills")
        with gr.Accordion("Advanced Settings", open=False):
            custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path (optional)", placeholder="multimodalart/vintage-ads-flux")
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        # Examples pass only the prompt, so predict's other arguments take
        # their defaults.
        gr.Examples(
            examples=examples,
            fn=predict,
            inputs=[prompt],
            outputs=[output, output_stills, seed],
            cache_examples="lazy",
        )
    # Clicking Submit or pressing Enter in the prompt box both run predict.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[prompt, seed, randomize_seed, guidance_scale, num_inference_steps, custom_lora],
        outputs=[output, output_stills, seed],
    )

demo.launch()