# KB-VQA / app.py
# Hugging Face file-view header (not part of the program):
# author m7mdal7aj, commit dc81fd5 "Update app.py", raw / history / blame, 1.97 kB.
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
def load_caption_model(blip2=False, instructblip=True):
    """Load a pretrained vision-language model together with its processor.

    InstructBLIP takes precedence when both flags are set. That is the same
    pair the earlier back-to-back ``if`` checks ended up returning, but the
    BLIP-2 weights are no longer loaded first only to be thrown away.

    Args:
        blip2: Load ``Salesforce/blip2-opt-2.7b``.
        instructblip: Load ``Salesforce/instructblip-vicuna-7b``.

    Returns:
        A ``(model, processor)`` tuple for the selected checkpoint.

    Raises:
        ValueError: If neither ``blip2`` nor ``instructblip`` is set.
    """
    # Loading options shared by both checkpoints: 8-bit weights, float16
    # dtype and automatic device placement. Dropping ``load_in_8bit`` gives
    # the plain float16 variant.
    load_kwargs = {
        "load_in_8bit": True,
        "torch_dtype": torch.float16,
        "device_map": "auto",
    }

    if instructblip:
        # Imported here because the file's top-level ``transformers`` import
        # only brings in the BLIP-2 classes.
        from transformers import (
            InstructBlipForConditionalGeneration,
            InstructBlipProcessor,
        )

        checkpoint = "Salesforce/instructblip-vicuna-7b"
        model = InstructBlipForConditionalGeneration.from_pretrained(
            checkpoint, **load_kwargs
        )
        processor = InstructBlipProcessor.from_pretrained(checkpoint)
    elif blip2:
        checkpoint = "Salesforce/blip2-opt-2.7b"
        processor = Blip2Processor.from_pretrained(checkpoint)
        model = Blip2ForConditionalGeneration.from_pretrained(
            checkpoint, **load_kwargs
        )
    else:
        # Previously this fell through to ``return`` with unbound locals.
        raise ValueError("Select a model: set blip2=True or instructblip=True.")

    return model, processor
def answer_question(image, question, model, processor):
    """Generate a free-text answer to ``question`` about ``image``.

    Args:
        image: Anything ``PIL.Image.open`` accepts; the app passes the
            Streamlit upload object.
        question: Prompt text given to the processor alongside the image.
        model: Loaded generation model, e.g. from ``load_caption_model``.
        processor: Processor that matches ``model``.

    Returns:
        The decoded first generated sequence, with special tokens removed
        and surrounding whitespace stripped.
    """
    rgb_image = Image.open(image).convert("RGB")

    encoded = processor(images=rgb_image, text=question, return_tensors="pt")
    # Follow the model's own placement instead of hard-coding "cuda": with
    # device_map="auto" the weights are not guaranteed to sit on the default
    # CUDA device.
    encoded = encoded.to(model.device, torch.float16)

    generated = model.generate(
        **encoded,
        max_length=200,
        min_length=20,
        num_beams=1,
    )
    return processor.decode(generated[0], skip_special_tokens=True).strip()
st.title("Image Question Answering")

# Streamlit re-executes this script on every widget interaction, so calling
# the loader directly would reload the model checkpoint on each click.
# st.cache_resource keeps the loaded (model, processor) pair alive across
# reruns; fall back to the plain loader on Streamlit releases without it.
_cache_resource = getattr(st, "cache_resource", None)
get_caption_model = (
    _cache_resource(load_caption_model) if _cache_resource else load_caption_model
)

# File uploader for the image
image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Text input for the question
question = st.text_input("Enter your question about the image:")

if st.button("Get Answer"):
    if image is not None and question:
        # Display the image
        st.image(image, use_column_width=True)

        # Get and display the answer
        model, processor = get_caption_model()
        answer = answer_question(image, question, model, processor)
        st.write(answer)
    else:
        st.write("Please upload an image and enter a question.")