#!/usr/bin/env python
# coding: utf-8
"""Gradio front-end for the DALL·E mini demo.

The user's prompt is POSTed to a remote inference backend whose URL is read
from the ``BACKEND_SERVER`` environment variable. The backend answers with a
JSON object whose ``"images"`` entry is a list of base64-encoded images; those
images are pasted side by side into one strip and shown in the Gradio UI.
"""

import base64
import html
import logging
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont

logger = logging.getLogger(__name__)

# If we use streamlit, this would be exported as a streamlit secret
backend_url = os.environ["BACKEND_SERVER"]

# (connect, read) timeout in seconds for backend requests. Without a timeout a
# stalled backend would block the request forever. The read timeout is generous
# because the backend generates many candidates per prompt.
# NOTE(review): tune the read timeout to the backend's real latency.
REQUEST_TIMEOUT = (10, 300)

# Font and band height used to render the optional caption in compose_predictions.
CAPTION_FONT_PATH = "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf"
CAPTION_FONT_SIZE = 40
CAPTION_HEIGHT = 48

# HTML shown above the result strip; ``{prompt}`` receives the escaped prompt.
RESULT_TITLE_TEMPLATE = """
Best predictions
We asked our model to generate 128 candidates for your prompt:
{prompt}
We then used a pre-trained CLIP model to score them according to the similarity of the text and the image representations.
This is the result:
"""

# HTML shown below the result strip.
RESULT_DESCRIPTION = """Read more about the process in our report.
Created with DALL·E mini
"""

# HTML shown instead of the result when the backend could not deliver images.
ERROR_TITLE = """
Sorry, there was an error retrieving the images.
Please, try again later or contact us here.
"""


def compose_predictions(images, caption=None):
    """Paste ``images`` left to right into a single RGB strip.

    Every tile is placed on a grid sized after ``images[0]``; images of a
    different size are pasted as-is at their grid position.

    Args:
        images: Non-empty sequence of PIL images.
        caption: Optional text drawn in white on a band of ``CAPTION_HEIGHT``
            pixels above the images.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("compose_predictions() needs at least one image")

    tile_w, tile_h = images[0].size
    caption_h = 0 if caption is None else CAPTION_HEIGHT
    strip = Image.new("RGB", (len(images) * tile_w, tile_h + caption_h))
    for i, tile in enumerate(images):
        strip.paste(tile, (i * tile_w, caption_h))

    if caption is not None:
        draw = ImageDraw.Draw(strip)
        try:
            font = ImageFont.truetype(CAPTION_FONT_PATH, CAPTION_FONT_SIZE)
        except OSError:
            # The TrueType font is not installed here: use PIL's built-in font
            # rather than failing the whole request over a caption.
            font = ImageFont.load_default()
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return strip


class ServiceError(Exception):
    """Raised when the inference backend cannot deliver images.

    Attributes:
        status_code: HTTP status returned by the backend, or ``None`` when no
            response was received at all (connection error, timeout).
    """

    def __init__(self, status_code, message=None):
        if message is None:
            message = f"backend request failed (HTTP status {status_code})"
        super().__init__(message)
        self.status_code = status_code


def get_images_from_ngrok(prompt, timeout=REQUEST_TIMEOUT):
    """Ask the backend at ``backend_url`` to generate images for ``prompt``.

    Args:
        prompt: Text prompt sent as ``{"prompt": prompt}`` JSON.
        timeout: ``requests`` timeout (seconds, or a ``(connect, read)`` tuple).

    Returns:
        A non-empty list of PIL images decoded from the response's
        base64-encoded ``"images"`` list.

    Raises:
        ServiceError: If the backend is unreachable, answers with a non-200
            status, or returns a payload that cannot be decoded into images.
    """
    try:
        response = requests.post(backend_url, json={"prompt": prompt}, timeout=timeout)
    except requests.RequestException as err:
        raise ServiceError(None, f"could not reach backend: {err}") from err

    if response.status_code != 200:
        raise ServiceError(response.status_code)

    try:
        encoded = response.json()["images"]
        images = [Image.open(BytesIO(base64.b64decode(img))) for img in encoded]
    except (KeyError, TypeError, ValueError, OSError) as err:
        # ValueError covers invalid JSON and invalid base64; OSError covers
        # bytes that PIL cannot identify as an image.
        raise ServiceError(
            response.status_code, f"malformed backend response: {err}"
        ) from err

    if not images:
        raise ServiceError(response.status_code, "backend returned no images")
    return images


def run_inference(prompt):
    """Gradio callback: fetch images for ``prompt`` and build the three outputs.

    Returns:
        ``(title_html, strip_image_or_None, description_html)``, matching the
        ``outputs`` list below. On a backend failure the image is ``None`` and
        the title carries an apology message.
    """
    try:
        images = get_images_from_ngrok(prompt)
    except ServiceError as err:
        logger.warning("Inference request failed: %s", err)
        return (ERROR_TITLE, None, "")

    predictions = compose_predictions(images)
    # The title is rendered as HTML, so escape the user-supplied prompt.
    output_title = RESULT_TITLE_TEMPLATE.format(prompt=html.escape(str(prompt)))
    return (output_title, predictions, RESULT_DESCRIPTION)


outputs = [
    gr.outputs.HTML(label=""),  # To be used as title
    gr.outputs.Image(label=""),
    gr.outputs.HTML(label=""),  # Additional text that appears in the screenshot
]

description = """
Welcome to our demo of DALL·E-mini.
This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
Please, write what you would like the model to generate, or select one of the examples below.
"""

demo = gr.Interface(
    run_inference,
    # NOTE: extra inputs such as gr.inputs.Slider(1, 64, 1, 8, label='Candidates to generate')
    # and gr.inputs.Slider(1, 8, 1, 1, label='Best predictions to show') are currently disabled.
    inputs=[gr.inputs.Textbox(label="Prompt")],
    outputs=outputs,
    title="DALL·E mini",
    description=description,
    article="DALL·E mini by Boris Dayma et al. | GitHub",
    layout="vertical",
    theme="huggingface",
    examples=[
        ["an armchair in the shape of an avocado"],
        ["snowy mountains by the sea"],
    ],
    allow_flagging=False,
    live=False,
    server_name="0.0.0.0",  # Listen on all IPv4 interfaces, not just localhost
    # server_port=8999
)
demo.launch()