m7mdal7aj committed on
Commit
e9d7d81
·
verified ·
1 Parent(s): 8f97cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -18
app.py CHANGED
@@ -7,25 +7,8 @@ from PIL import Image
7
  import torch.nn as nn
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
  from my_model.object_detection import detect_and_draw_objects
 
10
 
11
- def load_caption_model(blip2=False, instructblip=True):
12
-
13
- if blip2:
14
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
15
- model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
16
- if torch.cuda.device_count() > 1:
17
- model = nn.DataParallel(model)
18
- model.to('cuda')
19
- #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
20
-
21
- if instructblip:
22
- model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
23
- if torch.cuda.device_count() > 1:
24
- model = nn.DataParallel(model)
25
- model.to('cuda')
26
- processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
27
-
28
- return model, processor
29
 
30
 
31
 
@@ -54,7 +37,16 @@ image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
54
  # Text input for the question
55
  question = st.text_input("Enter your question about the image:")
56
 
 
 
 
 
 
 
57
 
 
 
 
58
  if st.button("Get Answer"):
59
  if image is not None and question:
60
  # Display the image
 
7
  import torch.nn as nn
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
  from my_model.object_detection import detect_and_draw_objects
10
+ from my_model.captioner.image_captioning import get_caption
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
 
 
37
  # Text input for the question
38
  question = st.text_input("Enter your question about the image:")
39
 
40
+ if st.button('Generate Caption'):
41
+ if image is not None:
42
+ # Display the image
43
+ st.image(image, use_column_width=True)
44
+ caption = get_caption(image)
45
+ st.write(caption)
46
 
47
+ else:
48
+ st.write("Please upload an image and enter a question.")
49
+
50
  if st.button("Get Answer"):
51
  if image is not None and question:
52
  # Display the image