Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import IdeficsForVisionText2Text, AutoProcessor
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
# Hub id of the fine-tuned IDEFICS checkpoint served by this demo.
model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights in bfloat16 and move them to `device`: predict() sends its
# inputs to `device`, so the model has to live there as well.
model = IdeficsForVisionText2Text.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to(device)

# The original line read `config.base_model_name_or_path`, but no `config`
# object is ever created in this file, so it raised NameError at start-up.
# NOTE(review): assumes the checkpoint repo ships its own processor files;
# if it does not, point this at the base IDEFICS checkpoint instead.
processor = AutoProcessor.from_pretrained(model_id)
|
11 |
+
|
12 |
+
def predict(prompt, image_url, max_length):
    """Generate a text description for the image found at ``image_url``.

    Args:
        prompt: Instruction text placed after the image in the model input.
        image_url: URL of the image; it is downloaded through the processor's
            ``image_processor.fetch_images`` helper.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            It is cast to ``int`` because UI sliders may deliver a float.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    image = processor.image_processor.fetch_images(image_url)

    # A single interleaved prompt: the image first, then the instruction.
    inputs = processor([image, prompt], return_tensors="pt").to(device)

    # Honour the caller-supplied limit; the previous code always used 128 and
    # silently ignored the `max_length` argument.
    generated_ids = model.generate(**inputs, max_length=int(max_length))
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

    # Echo the result to stdout so it appears in the application logs.
    print(generated_text)
    return generated_text
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating Midjourney like captions (describe functionality) with IDEFICS"
# Disabled leftovers from the template this demo was adapted from. Both lines
# of the `article` string are kept inside the comment so that no stray string
# literal is left behind as a statement.
# article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
#           "Repo</a></p> "
# examples = [['beatles.jpeg'], ['aurora.jpeg'], ['good_luck.png'], ['pokemons.jpg'], ['donuts.jpg']]

# Build the UI. The three inputs map positionally onto
# predict(prompt, image_url, max_length).
io = gr.Interface(
    # The original referenced `image_caption`, which is not defined anywhere
    # in this file; `predict` is the handler that actually exists.
    fn=predict,
    # Top-level component classes are used because the keyword arguments below
    # (`value`, `minimum`, `maximum`) belong to them; the legacy `gr.inputs.*`
    # and `gr.outputs.*` wrappers do not accept `value`, `min` or `max`.
    inputs=[
        gr.Textbox(value="Describe the following image:"),
        gr.Textbox(
            label="image URL",
            placeholder="Insert the URL of the image to be described",
        ),
        gr.Slider(label="Max tokens", value=64, maximum=128, minimum=16, step=8),
    ],
    outputs=gr.Textbox(label="IDEFICS Description"),
    title=title,
    description=description,
    # "never" is the string form of the former `allow_flagging=False`.
    # `allow_screenshot` was dropped: it may be rejected as an unknown keyword.
    allow_flagging="never",
)
# The keyword is `show_error` (singular); `show_errors` is not a valid argument.
io.launch(show_error=True)
|