boris commited on
Commit
a0b5dc7
·
1 Parent(s): dccd804

feat: update gradio app

Browse files
app/gradio/app_gradio.py CHANGED
@@ -18,12 +18,16 @@ from PIL import Image
18
  import numpy as np
19
  import matplotlib.pyplot as plt
20
 
21
-
22
  from vqgan_jax.modeling_flax_vqgan import VQModel
23
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
24
 
 
 
 
25
  import gradio as gr
26
 
 
 
27
 
28
  DALLE_REPO = 'flax-community/dalle-mini'
29
  DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
@@ -58,34 +62,12 @@ def generate(input, rng, params):
58
  def get_images(indices, params):
59
  return vqgan.decode_code(indices, params=params)
60
 
61
- def plot_images(images):
62
- fig = plt.figure(figsize=(40, 20))
63
- columns = 4
64
- rows = 2
65
- plt.subplots_adjust(hspace=0, wspace=0)
66
-
67
- for i in range(1, columns*rows +1):
68
- fig.add_subplot(rows, columns, i)
69
- plt.imshow(images[i-1])
70
- plt.gca().axes.get_yaxis().set_visible(False)
71
- plt.show()
72
-
73
- def stack_reconstructions(images):
74
- w, h = images[0].size[0], images[0].size[1]
75
- img = Image.new("RGB", (len(images)*w, h))
76
- for i, img_ in enumerate(images):
77
- img.paste(img_, (i*w,0))
78
- return img
79
-
80
  p_generate = jax.pmap(generate, "batch")
81
  p_get_images = jax.pmap(get_images, "batch")
82
 
83
  bart_params = replicate(model.params)
84
  vqgan_params = replicate(vqgan.params)
85
 
86
- # ## CLIP Scoring
87
- from transformers import CLIPProcessor, FlaxCLIPModel
88
-
89
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
90
  print("Initialize FlaxCLIPModel")
91
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@@ -137,7 +119,7 @@ def top_k_predictions(prompt, num_candidates=32, k=8):
137
 
138
  def run_inference(prompt, num_images=32, num_preds=8):
139
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
140
- predictions = compose_predictions(images)
141
  output_title = f"""
142
  <b>{prompt}</b>
143
  """
@@ -152,7 +134,7 @@ description = """
152
  DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
153
  """
154
  gr.Interface(run_inference,
155
- inputs=[gr.inputs.Textbox(label='What do you want to see?')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
156
  outputs=outputs,
157
  title='DALL·E mini',
158
  description=description,
 
18
  import numpy as np
19
  import matplotlib.pyplot as plt
20
 
 
21
  from vqgan_jax.modeling_flax_vqgan import VQModel
22
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
23
 
24
+ # ## CLIP Scoring
25
+ from transformers import CLIPProcessor, FlaxCLIPModel
26
+
27
  import gradio as gr
28
 
29
+ from dalle_mini.helpers import captioned_strip
30
+
31
 
32
  DALLE_REPO = 'flax-community/dalle-mini'
33
  DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
 
62
  def get_images(indices, params):
63
  return vqgan.decode_code(indices, params=params)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  p_generate = jax.pmap(generate, "batch")
66
  p_get_images = jax.pmap(get_images, "batch")
67
 
68
  bart_params = replicate(model.params)
69
  vqgan_params = replicate(vqgan.params)
70
 
 
 
 
71
  clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
72
  print("Initialize FlaxCLIPModel")
73
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
119
 
120
  def run_inference(prompt, num_images=32, num_preds=8):
121
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
122
+ predictions = captioned_strip(images)
123
  output_title = f"""
124
  <b>{prompt}</b>
125
  """
 
134
  DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
135
  """
136
  gr.Interface(run_inference,
137
+ inputs=[gr.inputs.Textbox(label='What do you want to see?')],
138
  outputs=outputs,
139
  title='DALL·E mini',
140
  description=description,
app/gradio/app_gradio_ngrok.py CHANGED
@@ -7,25 +7,15 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  from io import BytesIO
9
  import base64
 
10
 
11
  import gradio as gr
12
 
13
- # If we use streamlit, this would be exported as a streamlit secret
14
- import os
15
- backend_url = os.environ["BACKEND_SERVER"]
16
 
17
- def compose_predictions(images, caption=None):
18
- increased_h = 0 if caption is None else 48
19
- w, h = images[0].size[0], images[0].size[1]
20
- img = Image.new("RGB", (len(images)*w, h + increased_h))
21
- for i, img_ in enumerate(images):
22
- img.paste(img_, (i*w, increased_h))
23
 
24
- if caption is not None:
25
- draw = ImageDraw.Draw(img)
26
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
27
- draw.text((20, 3), caption, (255,255,255), font=font)
28
- return img
29
 
30
  class ServiceError(Exception):
31
  def __init__(self, status_code):
@@ -46,7 +36,7 @@ def get_images_from_ngrok(prompt):
46
  def run_inference(prompt):
47
  try:
48
  images = get_images_from_ngrok(prompt)
49
- predictions = compose_predictions(images)
50
  output_title = f"""
51
  <p style="font-size:22px; font-style:bold">Best predictions</p>
52
  <p>We asked our model to generate 128 candidates for your prompt:</p>
 
7
  import matplotlib.pyplot as plt
8
  from io import BytesIO
9
  import base64
10
+ import os
11
 
12
  import gradio as gr
13
 
14
+ from dalle_mini.helpers import captioned_strip
 
 
15
 
 
 
 
 
 
 
16
 
17
+ backend_url = os.environ["BACKEND_SERVER"]
18
+
 
 
 
19
 
20
  class ServiceError(Exception):
21
  def __init__(self, status_code):
 
36
  def run_inference(prompt):
37
  try:
38
  images = get_images_from_ngrok(prompt)
39
+ predictions = captioned_strip(images)
40
  output_title = f"""
41
  <p style="font-size:22px; font-style:bold">Best predictions</p>
42
  <p>We asked our model to generate 128 candidates for your prompt:</p>
app/sample_images/image_0.jpg DELETED
Binary file (9.02 kB)
 
app/sample_images/image_1.jpg DELETED
Binary file (9.71 kB)
 
app/sample_images/image_2.jpg DELETED
Binary file (14.1 kB)
 
app/sample_images/image_3.jpg DELETED
Binary file (9.38 kB)
 
app/sample_images/image_4.jpg DELETED
Binary file (9.97 kB)
 
app/sample_images/image_5.jpg DELETED
Binary file (15.3 kB)
 
app/sample_images/image_6.jpg DELETED
Binary file (11.1 kB)
 
app/sample_images/image_7.jpg DELETED
Binary file (8.55 kB)
 
app/sample_images/readme.txt DELETED
@@ -1 +0,0 @@
1
- These images were generated by one of our checkpoints, as responses to the prompt "snowy mountains by the sea".
 
 
app/ui_gradio.py DELETED
@@ -1,91 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from PIL import Image
5
- import gradio as gr
6
-
7
- def compose_predictions(images, caption=None):
8
- increased_h = 0 if caption is None else 48
9
- w, h = images[0].size[0], images[0].size[1]
10
- img = Image.new("RGB", (len(images)*w, h + increased_h))
11
- for i, img_ in enumerate(images):
12
- img.paste(img_, (i*w, increased_h))
13
-
14
- if caption is not None:
15
- draw = ImageDraw.Draw(img)
16
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
17
- draw.text((20, 3), caption, (255,255,255), font=font)
18
- return img
19
-
20
- def compose_predictions_grid(images):
21
- cols = 4
22
- rows = len(images) // cols
23
- w, h = images[0].size[0], images[0].size[1]
24
- img = Image.new("RGB", (w * cols, h * rows))
25
- for i, img_ in enumerate(images):
26
- row = i // cols
27
- col = i % cols
28
- img.paste(img_, (w * col, h * row))
29
- return img
30
-
31
- def top_k_predictions_real(prompt, num_candidates=32, k=8):
32
- images = hallucinate(prompt, num_images=num_candidates)
33
- images = clip_top_k(prompt, images, k=num_preds)
34
- return images
35
-
36
- def top_k_predictions(prompt, num_candidates=32, k=8):
37
- images = []
38
- for i in range(k):
39
- image = Image.open(f"sample_images/image_{i}.jpg")
40
- images.append(image)
41
- return images
42
-
43
- def run_inference(prompt, num_images=32, num_preds=8):
44
- images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
45
- predictions = compose_predictions(images)
46
- output_title = f"""
47
- <p style="font-size:22px; font-style:bold">Best predictions</p>
48
- <p>We asked our model to generate 32 candidates for your prompt:</p>
49
-
50
- <pre>
51
-
52
- <b>{prompt}</b>
53
- </pre>
54
- <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
55
- similarity of the text and the image representations.</p>
56
-
57
- <p>This is the result:</p>
58
- """
59
- output_description = """
60
- <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
61
- <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
62
- """
63
- return (output_title, predictions, output_description)
64
-
65
- outputs = [
66
- gr.outputs.HTML(label=""), # To be used as title
67
- gr.outputs.Image(label=''),
68
- gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
69
- ]
70
-
71
- description = """
72
- Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
73
- It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
74
-
75
- Please, write what you would like the model to generate, or select one of the examples below.
76
- """
77
- gr.Interface(run_inference,
78
- inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
79
- outputs=outputs,
80
- title='DALL·E mini',
81
- description=description,
82
- article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
83
- layout='vertical',
84
- theme='huggingface',
85
- examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
86
- allow_flagging=False,
87
- live=False,
88
- server_port=8999
89
- ).launch(
90
- share=True # Creates temporary public link if true
91
- )