m7mdal7aj commited on
Commit
eedbfb7
·
1 Parent(s): 9a3c83b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -11,13 +11,13 @@ from transformers import Blip2Processor, Blip2ForConditionalGeneration, Instruct
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
13
  if blip2:
14
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="auto")
15
- model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="auto")
16
  #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
17
 
18
  if instructblip:
19
- model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="auto")
20
- processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="auto")
21
 
22
  return model, processor
23
 
@@ -26,13 +26,13 @@ def load_caption_model(blip2=False, instructblip=True):
26
  def answer_question(image, question, model, processor):
27
 
28
 
29
- image = Image.open(image).convert('RGB')
30
 
31
 
32
 
33
  inputs = processor(image, question, return_tensors="pt").to("cuda", torch.float16)
34
 
35
- out = model.generate(**inputs, max_length=200, min_length=20, num_beams=3)
36
 
37
  answer = processor.decode(out[0], skip_special_tokens=True).strip()
38
  return answer
 
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
13
  if blip2:
14
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="cuda")
15
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="cuda")
16
  #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
17
 
18
  if instructblip:
19
+ model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="cuda")
20
+ processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16, device_map="cuda")
21
 
22
  return model, processor
23
 
 
26
  def answer_question(image, question, model, processor):
27
 
28
 
29
+ image = Image.open(image)
30
 
31
 
32
 
33
  inputs = processor(image, question, return_tensors="pt").to("cuda", torch.float16)
34
 
35
+ out = model.generate(**inputs, max_length=200, min_length=20).to("cuda", torch.float16)
36
 
37
  answer = processor.decode(out[0], skip_special_tokens=True).strip()
38
  return answer