|
import requests |
|
from PIL import Image |
|
from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel |
|
import gradio as gr |
|
import os |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
|
|
def load_model_and_components(model_name):
    """Fetch a vision encoder-decoder model together with its preprocessors.

    Args:
        model_name: Model identifier handed unchanged to each
            ``from_pretrained`` call below.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple, in that order.
    """
    # Built left to right, so the loading order matches the tuple order.
    components = (
        VisionEncoderDecoderModel.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
        AutoImageProcessor.from_pretrained(model_name),
    )
    return components
|
|
|
|
|
def preload_models():
    """Load every supported captioning model concurrently.

    Returns:
        A dict mapping each model name to the
        ``(model, tokenizer, image_processor)`` tuple produced by
        ``load_model_and_components``, keyed in the order listed below.
    """
    names = ["laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2"]
    # Each model loads on its own worker thread; Executor.map yields results
    # in input order, so zipping with ``names`` pairs them correctly.
    with ThreadPoolExecutor() as pool:
        loaded = dict(zip(names, pool.map(load_model_and_components, names)))
    return loaded
|
|
|
# Load all models once at import time. Maps model name ->
# (model, tokenizer, image_processor); generate_caption and the model
# dropdown both read from this dict.
models = preload_models()
|
|
|
|
|
# Directory holding the example images shown under the upload widget.
image_folder = "images"

# File extensions accepted as example images (matched case-insensitively).
_EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.ppm')


def _load_example_image(path):
    """Open the image at *path*, convert it to RGB, and close the file."""
    # ``convert`` returns a new, fully loaded image, so the source file can be
    # released immediately instead of waiting for garbage collection
    # (multi-frame formats such as GIF otherwise keep the handle open).
    with Image.open(path) as img:
        return img.convert("RGB")


# os.listdir returns entries in arbitrary order; sort so the example gallery
# is shown in a stable, predictable order across runs and machines.
predefined_images = [
    _load_example_image(os.path.join(image_folder, fname))
    for fname in sorted(os.listdir(image_folder))
    if fname.lower().endswith(_EXAMPLE_EXTENSIONS)
]
|
|
|
|
|
def preprocess_image(image):
    """Normalise an incoming image to RGB and blank the caption.

    Args:
        image: An image object exposing ``convert`` (a PIL image from the
            Gradio component), or None when nothing is loaded.

    Returns:
        An ``(image, caption)`` pair: the RGB-converted image (None if no
        image was given) and None, which clears the caption textbox.
    """
    rgb = None if image is None else image.convert("RGB")
    return rgb, None
|
|
|
|
|
def generate_caption(image, selected_model):
    """Generate a caption for *image* with the chosen preloaded model.

    Args:
        image: The image to caption (a PIL image from the Gradio component),
            or None when nothing has been uploaded.
        selected_model: Key into the module-level ``models`` dict.

    Returns:
        The decoded caption string, or a prompt asking the user to upload
        an image when *image* is None.
    """
    if image is None:
        return "Please upload an image to generate a caption."

    model, tokenizer, image_processor = models[selected_model]
    inputs = image_processor(image, return_tensors="pt")
    output_ids = model.generate(inputs.pixel_values)
    # A single image goes in, so only the first decoded sequence is returned.
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
|
|
|
|
|
# Build the demo UI: model picker, image input with caption output, and a
# gallery of example images, plus the event wiring between them.
with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
    gr.Markdown("""

# Welcome to the LAICSI-IFES space for Vision Encoder-Decoder (VED) demonstration

---

### Select an available model: Swin-DistilBERTimbau (168M) or Swin-GPorTuguese-2 (240M)

""")

    with gr.Row(variant='panel'):
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                value="laicsiifes/swin-distilbertimbau",
                label="Select Model"
            )

    gr.Markdown("""

---

### Upload image or example images below, and click `Generate`

""")

    with gr.Row(variant='panel'):
        with gr.Column():
            image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")
            generate_button = gr.Button("Generate")

    gr.Markdown("""---""")

    with gr.Row(variant='panel'):
        examples = gr.Examples(
            examples=predefined_images,
            fn=preprocess_image,
            inputs=[image_display],
            outputs=[image_display, output_text],
            label="Examples",
            # Without caching, Gradio only copies the example into the input
            # and does NOT call ``fn``; running it on click clears any caption
            # left over from a previously captioned image.
            run_on_click=True,
        )

    # Switching models invalidates the current image/caption pair.
    model_selector.change(fn=lambda: (None, None), outputs=[image_display, output_text])

    # A fresh upload is normalised to RGB and the old caption is cleared.
    image_display.upload(fn=preprocess_image, inputs=[image_display], outputs=[image_display, output_text])

    # Removing the image also removes its caption.
    image_display.clear(fn=lambda: None, outputs=[output_text])

    generate_button.click(fn=generate_caption, inputs=[image_display, model_selector], outputs=output_text)
|
|
|
# Start the (blocking) web server only when run as a script, so importing
# this module does not launch the app as a side effect.
if __name__ == "__main__":
    interface.launch(share=False)