#!/usr/bin/env python # coding: utf-8 # Uncomment to run on cpu # import os # os.environ["JAX_PLATFORM_NAME"] = "cpu" import random import jax import flax.linen as nn from flax.training.common_utils import shard from flax.jax_utils import replicate from transformers import BartTokenizer from PIL import Image, ImageDraw, ImageFont import numpy as np from vqgan_jax.modeling_flax_vqgan import VQModel from dalle_mini.model import CustomFlaxBartForConditionalGeneration # ## CLIP Scoring from transformers import CLIPProcessor, FlaxCLIPModel import gradio as gr from PIL import Image, ImageDraw, ImageFont DALLE_REPO = "flax-community/dalle-mini" DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114" VQGAN_REPO = "flax-community/vqgan_f16_16384" VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9" tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID) model = CustomFlaxBartForConditionalGeneration.from_pretrained( DALLE_REPO, revision=DALLE_COMMIT_ID ) vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID) def captioned_strip(images, caption=None, rows=1): increased_h = 0 if caption is None else 48 w, h = images[0].size[0], images[0].size[1] img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h)) for i, img_ in enumerate(images): img.paste(img_, (i // rows * w, increased_h + (i % rows) * h)) if caption is not None: draw = ImageDraw.Draw(img) font = ImageFont.truetype( "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40 ) draw.text((20, 3), caption, (255, 255, 255), font=font) return img def custom_to_pil(x): x = np.clip(x, 0.0, 1.0) x = (255 * x).astype(np.uint8) x = Image.fromarray(x) if not x.mode == "RGB": x = x.convert("RGB") return x def generate(input, rng, params): return model.generate( **input, max_length=257, num_beams=1, do_sample=True, prng_key=rng, eos_token_id=50000, pad_token_id=50000, params=params, ) def get_images(indices, params): return vqgan.decode_code(indices, params=params) p_generate = jax.pmap(generate, "batch") p_get_images = jax.pmap(get_images, "batch") bart_params = replicate(model.params) vqgan_params = replicate(vqgan.params) clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32") print("Initialize FlaxCLIPModel") processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") print("Initialize CLIPProcessor") def hallucinate(prompt, num_images=64): prompt = [prompt] * jax.device_count() inputs = tokenizer( prompt, return_tensors="jax", padding="max_length", truncation=True, max_length=128, ).data inputs = shard(inputs) all_images = [] for i in range(num_images // jax.device_count()): key = random.randint(0, 1e7) rng = jax.random.PRNGKey(key) rngs = jax.random.split(rng, jax.local_device_count()) indices = p_generate(inputs, rngs, bart_params).sequences indices = indices[:, :, 1:] images = p_get_images(indices, vqgan_params) images = np.squeeze(np.asarray(images), 1) for image in images: all_images.append(custom_to_pil(image)) return all_images def clip_top_k(prompt, images, k=8): inputs = processor(text=prompt, images=images, return_tensors="np", padding=True) outputs = clip(**inputs) logits = outputs.logits_per_text scores = np.array(logits[0]).argsort()[-k:][::-1] return [images[score] for score in scores] def compose_predictions(images, caption=None): increased_h = 0 if caption is None else 48 w, h = images[0].size[0], images[0].size[1] img = Image.new("RGB", (len(images) * w, h + increased_h)) for i, img_ in enumerate(images): img.paste(img_, (i * w, increased_h)) if caption is not None: draw = ImageDraw.Draw(img) font = ImageFont.truetype( "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40 ) draw.text((20, 3), caption, (255, 255, 255), font=font) return img def top_k_predictions(prompt, num_candidates=32, k=8): images = hallucinate(prompt, num_images=num_candidates) images = clip_top_k(prompt, images, k=k) return images def run_inference(prompt, num_images=32, num_preds=8): images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds) predictions = captioned_strip(images) output_title = f""" {prompt} """ return (output_title, predictions) outputs = [ gr.outputs.HTML(label=""), # To be used as title gr.outputs.Image(label=""), ] description = """ DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text: """ gr.Interface( run_inference, inputs=[gr.inputs.Textbox(label="What do you want to see?")], outputs=outputs, title="DALL·E mini", description=description, article="
", layout="vertical", theme="huggingface", examples=[ ["an armchair in the shape of an avocado"], ["snowy mountains by the sea"], ], allow_flagging=False, live=False, # server_port=8999 ).launch(share=True)