# NOTE: the lines below are residue from the web file viewer (status, size,
# commit hashes, line-number gutter); kept as comments so the module parses.
# Spaces:
# Sleeping
# Sleeping
# File size: 4,865 Bytes
# bc7d231 c7c92f9 dfda773 58e3cb5 63fc765 d5a60de 7f2e710 bc7d231 85f811b dc81fd5 ca90c3f 63fc765 dc81fd5 ca90c3f 63fc765 ca90c3f bc7d231 eedbfb7 bc7d231 8e2f248 fcca3a5 8e2f248 fcca3a5 bc7d231 4e1ae0e 7f2e710 0d3c647 7f2e710 0d3c647 7f2e710 4f92821 0d3c647 4f92821 c55a1ed 4f92821 7f2e710 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
from PIL import Image
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
from my_model.object_detection import ObjectDetector
def load_caption_model(blip2=False, instructblip=True):
    """Load the image question-answering model together with its processor.

    Args:
        blip2 (bool): Select the "Salesforce/blip2-opt-2.7b" checkpoint.
        instructblip (bool): Select the "Salesforce/instructblip-vicuna-7b"
            checkpoint. When both flags are set this one wins, exactly as
            before, but BLIP-2 is no longer loaded just to be thrown away.

    Returns:
        tuple: ``(model, processor)``. The model is wrapped in
        ``nn.DataParallel`` when more than one CUDA device is visible.

    Raises:
        ValueError: If neither ``blip2`` nor ``instructblip`` is set
            (previously this surfaced as an unbound-variable error).
    """
    if instructblip:
        model_cls = InstructBlipForConditionalGeneration
        processor_cls = InstructBlipProcessor
        checkpoint = "Salesforce/instructblip-vicuna-7b"
    elif blip2:
        model_cls = Blip2ForConditionalGeneration
        processor_cls = Blip2Processor
        checkpoint = "Salesforce/blip2-opt-2.7b"
    else:
        raise ValueError("Select a model: set blip2=True or instructblip=True.")

    model = model_cls.from_pretrained(checkpoint, load_in_8bit=True, torch_dtype=torch.float16)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        # Only the DataParallel wrapper is moved. A bare 8-bit model is already
        # placed on the GPU by the quantized loader, and Transformers rejects
        # an explicit .to() on quantized models.
        model.to('cuda')

    # NOTE(review): load_in_8bit/torch_dtype are model options; they are passed
    # to the processor unchanged from the original code — confirm they are needed.
    processor = processor_cls.from_pretrained(checkpoint, load_in_8bit=True, torch_dtype=torch.float16)
    return model, processor
def answer_question(image, question, model, processor):
    """Generate a textual answer to ``question`` about ``image``.

    Args:
        image: Anything ``PIL.Image.open`` accepts (path or file-like object).
        question (str): The question to ask about the image.
        model: The loaded generation model, optionally wrapped in
            ``torch.nn.DataParallel``.
        processor: The processor matching ``model``; used both to encode the
            inputs and to decode the generated token ids.

    Returns:
        str: The decoded answer with special tokens removed and surrounding
        whitespace stripped.
    """
    pil_image = Image.open(image)
    encoded = processor(pil_image, question, return_tensors="pt").to("cuda", torch.float16)

    # A DataParallel wrapper keeps the real model on its 'module' attribute,
    # which is where generate() is called from.
    generator = model.module if isinstance(model, torch.nn.DataParallel) else model
    token_ids = generator.generate(**encoded, max_length=100, min_length=20)

    return processor.decode(token_ids[0], skip_special_tokens=True).strip()
st.title("Image Question Answering")

# File uploader for the image
image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Text input for the question
question = st.text_input("Enter your question about the image:")

if st.button("Get Answer"):
    if image is not None and question:
        # Display the image
        st.image(image, use_column_width=True)

        # Streamlit re-runs this script on every interaction, so an uncached
        # call would reload the full model on each click. Cache the loaded
        # model when the installed Streamlit provides st.cache_resource;
        # otherwise fall back to a plain call.
        _cache_resource = getattr(st, "cache_resource", None)
        _get_caption_model = (
            _cache_resource(load_caption_model)
            if _cache_resource is not None
            else load_caption_model
        )

        # Get and display the answer
        model, processor = _get_caption_model()
        answer = answer_question(image, question, model, processor)
        st.write(answer)
    else:
        st.write("Please upload an image and enter a question.")
# Object detection controls, rendered in the sidebar.
st.sidebar.title("Object Detection")

# Model choice drives the default threshold below.
detect_model = st.sidebar.selectbox(
    "Choose a model for object detection:",
    ["detic", "yolov5"],
)

# Default threshold is 0.2 for yolov5 and 0.4 for any other choice.
_default_threshold = 0.2 if detect_model == "yolov5" else 0.4
threshold = st.sidebar.slider(
    "Select Detection Threshold",
    min_value=0.1,
    max_value=0.9,
    value=_default_threshold,
)

# True only on the script run triggered by clicking the button.
detect_button = st.sidebar.button("Detect Objects")
def perform_object_detection(image, model_name, threshold):
    """
    Run object detection on an image and return the annotated result.

    Any exception raised by the detection call is printed and reported back
    to the caller through the return value rather than propagated.

    Args:
        image (PIL.Image): The image on which to perform object detection.
        model_name (str): The name of the object detection model to use.
        threshold (float): The threshold for object detection.

    Returns:
        PIL.Image, str: The image with drawn bounding boxes and a string of
        detected objects; on failure, ``None`` and the error message.
    """
    # NOTE(review): detect_and_draw_objects is not imported or defined in the
    # visible part of this file — confirm where it is supposed to come from.
    try:
        annotated, labels = detect_and_draw_objects(image, model_name, threshold)
    except Exception as err:
        # Print the error for debugging
        print(f"Error in object detection: {err}")
        return None, str(err)
    else:
        return annotated, labels
# Check if the 'Detect Objects' button was clicked
if detect_button:
    if image is not None:
        try:
            # Open the uploaded image and display the original
            pil_image = Image.open(image)
            st.image(pil_image, use_column_width=True, caption="Original Image")
        except Exception as e:
            # A module-level script cannot 'return'; the else-branch below is
            # simply skipped when the image fails to load.
            st.error(f"Error loading image: {e}")
        else:
            # Perform object detection
            processed_image, detected_objects = perform_object_detection(
                pil_image, detect_model, threshold
            )

            # perform_object_detection signals failure with None; an explicit
            # check avoids relying on the truthiness of an image/array object.
            if processed_image is not None:
                # Display the image with detected objects
                st.image(processed_image, use_column_width=True, caption="Image with Detected Objects")
            else:
                st.error("Failed to process image for object detection.")

            # Display the detected objects (or the error message) as text
            st.write(detected_objects)
    else:
        st.write("Please upload an image for object detection.")
# |