from pathlib import Path

import gradio as gr
import torch
from diffusers import DiffusionPipeline

from visual_anagrams.views import get_views, VIEW_MAP_NAMES
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import save_illusion
from visual_anagrams.animate import animate_two_view

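# Load both DeepFloyd IF stages up front: stage I generates the 64x64 base
# image and stage II upscales it to 256x256. Stage II reuses stage I's text
# embeddings, so its own text encoder is dropped (text_encoder=None).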
stage_1 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)
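
# Keep submodules on CPU until they are needed so both pipelines fit in
# limited GPU memory.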
stage_1.enable_model_cpu_offload()
stage_2.enable_model_cpu_offload()


def generate_content(
    style,
    prompt_for_original,
    prompt_for_transformed,
    transformation,
    num_inference_steps,
    seed
):
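    """Generate a two-view optical illusion from a pair of prompts.

    Returns paths to the flip-animation video, the illusion image, and an
    image showing both views side by side.
    """
    # Prepend the shared style to both prompts and embed them with stage I's
    # text encoder; encode_prompt returns (positive, negative) embeddings,
    # which are stacked so each denoising step conditions on both prompts.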
    prompts = [f'{style} {p}'.strip() for p in [prompt_for_original, prompt_for_transformed]]
    prompt_embeds = [stage_1.encode_prompt(p) for p in prompts]
    prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
    prompt_embeds = torch.cat(prompt_embeds)
    negative_prompt_embeds = torch.cat(negative_prompt_embeds)

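    # Pair the identity view (first prompt) with the user-selected
    # transformation (second prompt).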
    view_names = ['identity', VIEW_MAP_NAMES[transformation]]
    views = get_views(view_names)

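    # Fix the random seed so the same inputs reproduce the same illusion.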
    generator = torch.manual_seed(seed)

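    # Stage I: sample a 64x64 illusion; each denoising step combines the
    # noise estimates predicted under every view.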
    print("Sample stage 1")
    image = sample_stage_1(stage_1,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)

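    # Stage II: upscale to 256x256 with the same view-consistent sampling.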
    print("Sample stage 2")
    image = sample_stage_2(stage_2,
                           image,
                           prompt_embeds,
                           negative_prompt_embeds,
                           views,
                           num_inference_steps=num_inference_steps,
                           generator=generator)
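
    # save_illusion writes sample_{size}.png (the illusion) and
    # sample_{size}.views.png (all views side by side); Path("") places both
    # in the current working directory.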
    save_illusion(image, views, Path(""))

    size = image.shape[-1]
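
    # Render a flip animation between the two views; the return value below
    # assumes animate_two_view saves it to tmp.mp4, its default output path.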
    animate_two_view(
        f"sample_{size}.png",
        views[1],
        prompts[0],
        prompts[1],
    )
    return 'tmp.mp4', f"sample_{size}.png", f"sample_{size}.views.png"


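# Build the Gradio form: style prefix, the two prompts, the view
# transformation, and sampling controls.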
choices = list(VIEW_MAP_NAMES.keys())
gradio_app = gr.Interface(
    fn=generate_content,
    title="Multi-View Illusion Diffusion",
    inputs=[
        gr.Textbox(label="Style", placeholder="an oil painting of"),
        gr.Textbox(label="Prompt for original view", placeholder="a dress"),
        gr.Textbox(label="Prompt for transformed view", placeholder="an old man"),
        gr.Dropdown(label="View transformation", choices=choices, value=choices[0]),
        gr.Number(label="Number of diffusion steps", value=50, step=1, minimum=1, maximum=300),
        gr.Number(label="Random seed", value=0, step=1, minimum=0, maximum=100000)
    ],
    outputs=[gr.Video(label="Illusion"), gr.Image(label="Original"), gr.Image(label="Transformed")],
)


if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable from outside the host/container.
    gradio_app.launch(server_name="0.0.0.0")