File size: 1,908 Bytes
486f4bf
c7081df
 
111e2df
486f4bf
 
b87be57
 
486f4bf
 
ce0ce67
111e2df
c7081df
111e2df
 
ce0ce67
 
486f4bf
78cea7d
 
 
 
 
 
 
486f4bf
 
8d07585
486f4bf
a853a81
78cea7d
486f4bf
 
8d07585
b390d84
458c57f
78cea7d
 
 
 
5bbc196
 
458c57f
1651e3c
486f4bf
78cea7d
9fd6fc9
cf8abff
 
 
b20969d
232d5c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text
#from peft import PeftModel, PeftConfig

# Hub id of the checkpoint to serve. The variable keeps its "peft" name from the
# adapter-based setup that is commented out below; it currently points at the
# base IDEFICS checkpoint.
#peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"
peft_model_id = "HuggingFaceM4/idefics-9b"
device = "cuda" if torch.cuda.is_available() else "cpu"


#config = PeftConfig.from_pretrained(peft_model_id)
# Load the generative model itself. An AutoProcessor only prepares inputs and
# decodes outputs; it has no .to(), .eval() or .generate(), all of which are
# called on `model` further down.
model = IdeficsForVisionText2Text.from_pretrained(peft_model_id, torch_dtype=torch.bfloat16)
#model = PeftModel.from_pretrained(model, peft_model_id)
# The processor turns image + text prompts into tensors and decodes generated ids.
processor = AutoProcessor.from_pretrained(peft_model_id)
model = model.to(device)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Instruction text sent along with every image (noted by the original author
# as the best-performing prompt for the fine-tuned model).
prompt = "Describe the following image:"

# Value handed to generate() as `max_length` in predict().
max_length = 64

def predict(image):
    """Generate a text description for *image*.

    Uses the module-level ``processor``, ``model``, ``device``, ``prompt`` and
    ``max_length``.

    Args:
        image: The picture to describe. The Gradio input in this file is
            configured with ``type="pil"``; the value is ``None`` when the
            form is submitted without an image.

    Returns:
        The decoded generation with the instruction prompt removed, or an
        empty string when no image was supplied.
    """
    # An empty image input arrives as None; return early instead of failing
    # inside the processor with an opaque error.
    if image is None:
        return ""

    # One interleaved prompt: the image followed by the text instruction.
    inputs = processor([image, prompt], return_tensors="pt").to(device)
    # NOTE: `max_length` bounds prompt tokens plus generated tokens, not only
    # the newly generated ones.
    generated_ids = model.generate(**inputs, max_length=max_length)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Remove the instruction text so only the description is returned.
    return generated_text.replace(f"{prompt} ", "")

# Heading and blurb shown at the top of the demo page.
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

# Example rows offered in the UI: one single-element row per image path,
# matching the single image input of the interface.
examples = [
    [image_path]
    for image_path in (
        "1_sTXgMwDUW0pk-1yK4iHYFw.png",
        "0_6as5rHi0sgG4W2Tq.png",
        "zoomout_2-1440x807.jpg",
        "inZdRVn7eafZNvaVre2iW1a538.webp",
        "cute-photos-of-cats-in-grass-1593184777.jpg",
        "llama2-coder-logo.png",
    )
]

# Wire predict() to one image input (delivered as a PIL image) and one text output.
io = gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Upload an image", type="pil")],
    outputs=[gr.Textbox(label="IDEFICS Description")],
    title=title,
    description=description,
    examples=examples,
)

# Start the web app at import/run time with Gradio's debug mode enabled.
io.launch(debug=True)