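# app.py - Gradio demo Space for SmolVLM-Instruct-DPO.
# Hugging Face file-viewer residue, kept as comments so the module parses:
#   author avatar: sergiopaniego's picture
#   last commit: "Update app.py" (706519b, verified)
#   viewer links: raw / history / blame; file size: 3.66 kB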
from datetime import datetime
import os
import tempfile

# Third-party imports keep their original relative order; in particular
# `spaces` stays ahead of `torch`/`transformers`.
import gradio as gr
import spaces
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
import torch
from PIL import Image
import numpy as np
# Markdown text describing the demo; it is rendered at the top of the page by
# the gr.Markdown(DESCRIPTION) call in the UI definition.
DESCRIPTION = """
# SmolVLM-trl-dpo-rlaif-v Demo
This is a demo Space for a fine-tuned version of [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) trained using [rlaif-v dataset](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted).
The corresponding model is located [here](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct-DPO).
For a full tutorial of fine-tuning using DPO, check out [this link](https://huggingface.co/learn/cookbook/index).
"""
# Hub identifiers: the base checkpoint and the DPO adapter applied on top of it.
# (An earlier adapter, "sergiopaniego/smolvlm-instruct-trl-dpo-rlaif-v", was
# left commented out in the original file.)
model_id = "HuggingFaceTB/SmolVLM-Instruct"
adapter_path = "HuggingFaceTB/SmolVLM-Instruct-DPO"

# The processor turns the chat prompt and image into model inputs.
processor = AutoProcessor.from_pretrained(model_id)

# Load the base model in bfloat16, then attach the adapter weights.
# NOTE: `_attn_implementation="flash_attention_2"` was deliberately left
# disabled (commented out) in the original call.
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.load_adapter(adapter_path)
def array_to_image_path(image_array):
    """Save an image array as a PNG file and return the file's absolute path.

    The file is created in the system temporary directory under a name made
    of a timestamp plus a random component. A timestamp alone has one-second
    resolution, so two calls within the same second would otherwise write to
    the same file and overwrite each other.

    Args:
        image_array: Pixel data; it is cast to ``uint8`` and handed to
            ``PIL.Image.fromarray``.

    Returns:
        Absolute path of the written PNG file. The file is not deleted
        automatically; the caller owns it and should remove it when done.

    Raises:
        ValueError: If ``image_array`` is ``None`` (no image was uploaded).
    """
    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    img = Image.fromarray(np.uint8(image_array))

    # The timestamp keeps file names human-sortable; NamedTemporaryFile adds
    # the random part that guarantees uniqueness across concurrent requests.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with tempfile.NamedTemporaryFile(
        prefix=f"image_{timestamp}_", suffix=".png", delete=False
    ) as tmp:
        # Writing to a file object gives PIL no extension to inspect, so the
        # format is stated explicitly.
        img.save(tmp, format="PNG")

    return os.path.abspath(tmp.name)
@spaces.GPU
def run_example(image, text_input=None):
    """Answer a question about an image with the adapter-augmented SmolVLM model.

    Args:
        image: Image pixels as an array accepted by ``PIL.Image.fromarray``
            (presumably the NumPy array that ``gr.Image`` passes by default;
            verify against the UI wiring).
        text_input: The user's question. ``None`` is treated as an empty
            question so the literal text "None" never reaches the prompt.

    Returns:
        The generated answer as a string, with the prompt tokens and special
        tokens stripped.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was uploaded).
    """
    # Validate up front. The original did this by saving the image to disk
    # through a helper whose returned path was never used, which left one
    # orphaned PNG on disk per request.
    if image is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    # A single convert() normalises every input mode (grayscale, RGBA, ...),
    # so no further mode check is needed afterwards.
    pil_image = Image.fromarray(image).convert("RGB")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "" if text_input is None else text_input},
            ],
        }
    ]

    # Preparation for inference: render the chat prompt as text, then let the
    # processor pair it with the image (one list of images per prompt).
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=prompt,
        images=[[pil_image]],
        padding=True,
        return_tensors="pt",
    )
    # Follow the device the model was actually placed on instead of assuming
    # "cuda"; the model is loaded with device_map="auto".
    inputs = inputs.to(model.device)

    # Inference: generate() returns prompt + completion for each sequence, so
    # the prompt-length prefix is cut off before decoding.
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    completion_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
# Custom stylesheet passed to gr.Blocks(css=css). The single rule targets the
# element with id "output": fixed 500px height, scrollable overflow and a thin
# grey border.
# NOTE(review): no component visible in this file sets elem_id="output", so
# this rule may currently match nothing - confirm whether the output Textbox
# is meant to carry that id.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# UI layout: description on top, then one tab with the inputs (image, question,
# submit button) in the left column and the model's answer in the right column.
with gr.Blocks(css=css) as demo:
    gr.Markdown(value=DESCRIPTION)

    with gr.Tab("SmolVLM-Instruct-DPO Input"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                text_input = gr.Textbox(label="Question")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")

        # Clicking "Submit" sends the image and the question to run_example
        # and writes the returned string into the output box.
        submit_btn.click(
            fn=run_example,
            inputs=[input_img, text_input],
            outputs=[output_text],
        )

# Enable the request queue with api_open=False, then start the server.
demo.queue(api_open=False)
demo.launch(debug=True)