# Page header captured with this file: last commit "Update app.py" (522a810, verified) by gabrielmotablima; file size 3.69 kB.
import requests
from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
import gradio as gr
import os
from concurrent.futures import ThreadPoolExecutor
# Load the model, tokenizer, and image processor for one checkpoint
def load_model_and_components(model_name):
    """Load a captioning model together with its tokenizer and image processor.

    Args:
        model_name: Identifier passed unchanged to each ``from_pretrained`` call.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple.

    Note:
        No exception is caught here; anything raised by a ``from_pretrained``
        call propagates to the caller.
    """
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    return model, tokenizer, image_processor
# Preload both models in parallel
def preload_models():
    """Load every supported checkpoint concurrently.

    Each model name is handed to ``load_model_and_components`` on a worker
    thread of a ``ThreadPoolExecutor``.

    Returns:
        A dict mapping each model name to the
        ``(model, tokenizer, image_processor)`` tuple loaded for it.

    Note:
        An exception raised while loading any model is re-raised here when
        its result is consumed.
    """
    model_names = ["laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2"]
    with ThreadPoolExecutor() as executor:
        # executor.map yields results in the order of its inputs, so zipping
        # them back with model_names pairs each name with its own components.
        loaded = executor.map(load_model_and_components, model_names)
        return dict(zip(model_names, loaded))
models = preload_models()

# Predefined images for selection
image_folder = "images"
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.ppm')


def _load_rgb_image(path):
    """Open the image file at *path* and return an RGB copy of it."""
    # The context manager closes the file handle once the converted copy
    # has been produced, instead of leaving it open until garbage collection.
    with Image.open(path) as img:
        return img.convert("RGB")


# Sorted so the examples appear in a stable order; os.listdir order is arbitrary.
predefined_images = [
    _load_rgb_image(os.path.join(image_folder, fname))
    for fname in sorted(os.listdir(image_folder))
    if fname.lower().endswith(_IMAGE_EXTENSIONS)
]
# Function to preprocess the image to RGB format
def preprocess_image(image):
    """Convert *image* to RGB and pair it with an empty caption.

    Args:
        image: An object exposing ``convert(mode)``, or ``None``.

    Returns:
        A 2-tuple ``(rgb_image, None)``. The second element is always
        ``None``; the UI wires it to the caption textbox, so it resets any
        previous caption. When *image* is ``None`` the result is
        ``(None, None)``.
    """
    if image is None:
        return None, None
    return image.convert("RGB"), None
# Function to process the image and generate a caption
def generate_caption(image, selected_model):
    """Generate a caption for *image* using the model named *selected_model*.

    Args:
        image: The image to caption, or ``None`` when nothing is loaded.
        selected_model: Key into the module-level ``models`` mapping.

    Returns:
        The decoded caption string, or a user-facing message when no image
        was supplied or the requested model is not loaded.
    """
    if image is None:
        return "Please upload an image to generate a caption."
    # Report an unknown model as a message rather than a bare KeyError,
    # consistent with how the missing-image case is reported.
    if selected_model not in models:
        return f"Unknown model: {selected_model!r}. Please select one of the available models."

    model, tokenizer, image_processor = models[selected_model]
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per generated sequence; only the first is used.
    caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption
# Define UI
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# the order of the components; confirm it against the intended layout
# (in particular, that the button shares a column with the caption textbox).
with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
    gr.Markdown("""
# Welcome to the LAICSI-IFES space for Vision Encoder-Decoder (VED) demonstration
---
### Select an available model: Swin-DistilBERTimbau (168M) or Swin-GPorTuguese-2 (240M)
""")

    # Model selection
    with gr.Row(variant='panel'):
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                value="laicsiifes/swin-distilbertimbau",
                label="Select Model"
            )

    gr.Markdown("""
---
### Upload image or example images below, and click `Generate`
""")

    # Image input on the left, caption output and trigger on the right
    with gr.Row(variant='panel'):
        with gr.Column():
            image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")
            generate_button = gr.Button("Generate")

    gr.Markdown("""---""")

    # Example images loaded from the local image folder
    with gr.Row(variant='panel'):
        examples = gr.Examples(
            examples=predefined_images,
            fn=preprocess_image,
            inputs=[image_display],
            outputs=[image_display, output_text],
            label="Examples"
        )

    # Define actions
    # Event listeners are registered inside the Blocks context.
    # Switching model clears both the image and the caption.
    model_selector.change(fn=lambda: (None, None), outputs=[image_display, output_text])
    # A new upload is converted to RGB; preprocess_image's second value (None) clears the caption.
    image_display.upload(fn=preprocess_image, inputs=[image_display], outputs=[image_display, output_text])
    # Clearing the image also clears the caption.
    image_display.clear(fn=lambda: None, outputs=[output_text])
    generate_button.click(fn=generate_caption, inputs=[image_display, model_selector], outputs=output_text)

interface.launch(share=False)