File size: 3,816 Bytes
704ee93
 
 
 
 
 
 
 
 
 
 
 
6be3159
 
 
704ee93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6be3159
704ee93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a61d80f
704ee93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb6780a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python
# coding: utf-8

import requests
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import base64

import gradio as gr

# URL that image-generation requests are POSTed to. It is read from the
# BACKEND_SERVER environment variable; the lookup raises KeyError at import
# time when the variable is not set.
# If we used streamlit, this would be exported as a streamlit secret instead.
import os
backend_url = os.environ["BACKEND_SERVER"]

def compose_predictions(images, caption=None):
    """Paste the given images side by side into a single RGB strip.

    Args:
        images: Non-empty sequence of PIL images. Each image is pasted at a
            horizontal offset that is a multiple of the first image's width,
            so the images are expected to share the first image's size.
        caption: Optional text drawn in white inside a 48-pixel band added
            above the images. With ``None`` no band is added.

    Returns:
        A new RGB image of size ``(len(images) * w, h)``, or
        ``(len(images) * w, h + 48)`` when a caption is given, where
        ``(w, h)`` is the size of the first image.

    Raises:
        IndexError: If ``images`` is an empty sequence.
    """
    increased_h = 0 if caption is None else 48
    w, h = images[0].size
    img = Image.new("RGB", (len(images) * w, h + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i * w, increased_h))

    if caption is not None:
        # The module only imports ``Image`` from PIL; without this import the
        # captioned path fails with NameError.
        from PIL import ImageDraw, ImageFont

        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf",
                40,
            )
        except OSError:
            # The font file is host-specific; fall back to Pillow's built-in
            # font rather than failing the whole composition.
            font = ImageFont.load_default()
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img

class ServiceError(Exception):
    """Raised when a request to the image backend does not succeed.

    Attributes:
        status_code: The status value supplied by the code that raised the
            error (the HTTP status code of the failed response).
    """

    def __init__(self, status_code):
        # Forward explicitly so ``args`` is populated regardless of whether
        # the exception was built positionally or with a keyword argument.
        super().__init__(status_code)
        self.status_code = status_code

    def __str__(self):
        # A bare number is unhelpful in logs and tracebacks; say what failed.
        return f"Backend request failed (HTTP status: {self.status_code})"

def get_images_from_ngrok(prompt, timeout=300):
    """Ask the backend to generate images for ``prompt`` and decode them.

    POSTs ``{"prompt": prompt}`` as JSON to ``backend_url``. A 200 response is
    expected to carry a JSON object whose ``"images"`` entry is a list of
    base64-encoded image files.

    Args:
        prompt: Text to generate images for.
        timeout: Seconds to wait for the backend before giving up. Passed
            straight to ``requests.post``; use ``None`` to wait indefinitely.

    Returns:
        A list of PIL images, one per entry of the response's ``"images"``.

    Raises:
        ServiceError: If the request cannot be completed (``status_code`` is
            ``None``), if the backend answers with a status other than 200,
            or if the response body cannot be decoded into images.
    """
    try:
        r = requests.post(
            backend_url,
            json={"prompt": prompt},
            timeout=timeout,
        )
    except requests.RequestException as err:
        # No response was received, so there is no HTTP status to report.
        raise ServiceError(None) from err

    if r.status_code != 200:
        raise ServiceError(r.status_code)

    try:
        encoded_images = r.json()["images"]
        return [
            Image.open(BytesIO(base64.b64decode(encoded)))
            for encoded in encoded_images
        ]
    except (KeyError, TypeError, ValueError, OSError) as err:
        # Malformed JSON, a missing "images" key, invalid base64 or bytes that
        # are not an image: surface them through the same error type callers
        # already handle.
        raise ServiceError(r.status_code) from err
        
def run_inference(prompt):
    """Generate images for ``prompt`` and build the three interface outputs.

    Args:
        prompt: Text entered by the user.

    Returns:
        A ``(title_html, image, description_html)`` tuple. On success the
        image is the result of ``compose_predictions`` on the images returned
        by the backend. When the backend request fails or yields no images,
        the title is an apology message, the image is ``None`` and the
        description is an empty string.
    """
    # Function-scope import: ``html`` is standard library and only needed here.
    import html

    try:
        images = get_images_from_ngrok(prompt)
    except ServiceError:
        images = []

    if not images:
        # A failed request and an empty result get the same apology; composing
        # an empty list would otherwise raise IndexError.
        output_title = """
        Sorry, there was an error retrieving the images. Please, try again later or <a href="mailto:[email protected]">contact us here</a>.
        """
        return (output_title, None, "")

    predictions = compose_predictions(images)

    # The prompt is user input rendered as HTML, so escape it before
    # interpolation to prevent markup injection.
    safe_prompt = html.escape(str(prompt))
    output_title = f"""
        <p style="font-size:22px; font-weight:bold">Best predictions</p>
        <p>We asked our model to generate 128 candidates for your prompt:</p>

        <pre>

        <b>{safe_prompt}</b>
        </pre>
        <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
        similarity of the text and the image representations.</p>

        <p>This is the result:</p>
        """

    output_description = """
        <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.</p>
        <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
        """

    return (output_title, predictions, output_description)
    
# Output components, in the same order as the tuple returned by
# run_inference: (title HTML, composed image, description HTML).
outputs = [
    gr.outputs.HTML(label=""),      # To be used as title
    gr.outputs.Image(label=''),     # Image built by compose_predictions (None on error)
    gr.outputs.HTML(label=""),      # Additional text that appears in the screenshot
]

# Introductory text passed to the interface as its `description`.
description = """
Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.

Please, write what you would like the model to generate, or select one of the examples below.
"""
# Build the interface around run_inference (one text prompt in, the three
# `outputs` components out) and launch it. This runs at module level, so the
# app starts as soon as this file is executed or imported.
gr.Interface(run_inference, 
    inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')], 
    outputs=outputs, 
    title='DALL·E mini',
    description=description,
    article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
    layout='vertical',
    theme='huggingface',
    # Each example is a list with one value per input component (here: the prompt).
    examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
    allow_flagging=False,
    live=False,
    # NOTE(review): "0.0.0.0" presumably makes the server listen on all
    # network interfaces — confirm against the Gradio version in use.
    server_name="0.0.0.0",
    # server_port=8999
).launch()