rishabh-zuma
commited on
Commit
·
56a6bf1
1
Parent(s):
eca2bf8
Updated with new code
Browse files- handler.py +10 -6
handler.py
CHANGED
@@ -21,8 +21,8 @@ class EndpointHandler():
|
|
21 |
|
22 |
|
23 |
print(" $$$$ Model Loading $$$$")
|
24 |
-
self.processor = Blip2Processor.from_pretrained("
|
25 |
-
self.model = Blip2ForConditionalGeneration.from_pretrained("
|
26 |
print(" $$$$ model loaded $$$$")
|
27 |
self.model.eval()
|
28 |
self.model = self.model.to(device)
|
@@ -66,9 +66,13 @@ class EndpointHandler():
|
|
66 |
# img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
|
67 |
# raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
|
68 |
|
69 |
-
|
|
|
|
|
|
|
70 |
|
71 |
-
out = self.model.generate(**inputs)
|
72 |
-
captions = processor.decode(out[0], skip_special_tokens=True)
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
print(" $$$$ Model Loading $$$$")
|
24 |
+
self.processor = Blip2Processor.from_pretrained("blip2/sharded")
|
25 |
+
self.model = Blip2ForConditionalGeneration.from_pretrained("blip2/sharded", device_map = "auto", load_in_8bit = True)
|
26 |
print(" $$$$ model loaded $$$$")
|
27 |
self.model.eval()
|
28 |
self.model = self.model.to(device)
|
|
|
66 |
# img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
|
67 |
# raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
|
68 |
|
69 |
+
generated_ids = self.processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
|
70 |
+
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
71 |
+
print("@@@@@@ generated_text @@@@@@@")
|
72 |
+
print(generated_text)
|
73 |
|
|
|
|
|
74 |
|
75 |
+
# out = self.model.generate(**inputs)
|
76 |
+
# captions = processor.decode(out[0], skip_special_tokens=True)
|
77 |
+
|
78 |
+
return {"captions": generated_text}
|