rishabh-zuma committed
Commit 56a6bf1 · 1 Parent(s): eca2bf8

Updated with new code

Files changed (1): handler.py (+10 -6)
handler.py CHANGED

@@ -21,8 +21,8 @@ class EndpointHandler():
 
 
         print(" $$$$ Model Loading $$$$")
-        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
+        self.processor = Blip2Processor.from_pretrained("blip2/sharded")
+        self.model = Blip2ForConditionalGeneration.from_pretrained("blip2/sharded", device_map = "auto", load_in_8bit = True)
         print(" $$$$ model loaded $$$$")
         self.model.eval()
         self.model = self.model.to(device)
@@ -66,9 +66,13 @@ class EndpointHandler():
         # img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
         # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
 
-        inputs = self.processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)
+        generated_ids = self.processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
+        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        print("@@@@@@ generated_text @@@@@@@")
+        print(generated_text)
 
-        out = self.model.generate(**inputs)
-        captions = processor.decode(out[0], skip_special_tokens=True)
 
-        return {"captions": captions}
+        # out = self.model.generate(**inputs)
+        # captions = processor.decode(out[0], skip_special_tokens=True)
+
+        return {"captions": generated_text}
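Note on the committed change: generated_ids here is the processor output (preprocessed pixel values), not token ids from the model. The self.model.generate(**inputs) call that would produce token ids is left commented out, so batch_decode is applied to model inputs rather than generated tokens, and no caption is actually produced. For reference, a minimal sketch of the standard BLIP-2 captioning flow from the Transformers documentation; the checkpoint id and image path are illustrative stand-ins, since the commit loads "blip2/sharded" instead:

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Illustrative checkpoint; the commit points at "blip2/sharded" instead.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True
)

raw_image = Image.open("demo.jpg").convert("RGB")  # illustrative local image

# Preprocess the image, have the model generate caption token ids, then decode them.
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
generated_ids = model.generate(**inputs)  # the step the commit leaves commented out
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(caption)

Separately, load_in_8bit=True requires the bitsandbytes package, and device_map="auto" already places the model, so the unchanged self.model = self.model.to(device) a few lines below the edit is redundant; recent Transformers versions raise a ValueError when .to() is called on an 8-bit-quantized model.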