adasdimchom committed
Commit c37d664 · 1 Parent(s): 6140afd

Upload handler.py

Files changed (1)
  1. handler.py +13 -12
handler.py CHANGED
@@ -15,8 +15,8 @@ class EndpointHandler():
         self.generate_model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
         self.generate_model.to(self.device)
 
-        self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
-        self.feature_model.to(self.device)
+        #self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
+        #self.feature_model.to(self.device)
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -33,27 +33,28 @@ class EndpointHandler():
             prompt = inputs["prompt"]
         else:
             prompt = None
-        if "extract_feature" in inputs:
-            extract_feature = inputs["extract_feature"]
-        else:
-            extract_feature = False
+        #if "extract_feature" in inputs:
+        #    extract_feature = inputs["extract_feature"]
+        #else:
+        #    extract_feature = False
 
         image = Image.open(requests.get(image_url, stream=True).raw)
         processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
         generated_ids = self.generate_model.generate(**processed_image)
         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
         result["image_caption"] = generated_text
-        if extract_feature:
-            caption_feature = self.feature_model(**processed_image)
-            result["caption_feature"] = caption_feature
+
+        #if extract_feature:
+        #    caption_feature = self.feature_model(**processed_image)
+        #    result["caption_feature"] = caption_feature
 
         if prompt:
             prompt_image_processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
             generated_ids = self.generate_model.generate(**prompt_image_processed)
             generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
             result["image_prompt"] = generated_text
-            if extract_feature:
-                prompt_feature = self.feature_model(**prompt_image_processed)
-                result["prompt_feature"] = prompt_feature
+            #if extract_feature:
+            #    prompt_feature = self.feature_model(**prompt_image_processed)
+            #    result["prompt_feature"] = prompt_feature
 
         return result
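
For context, a minimal sketch of how the resulting handler might be exercised locally. It is not part of the commit: the "inputs" envelope and the constructor's path argument follow the usual Hugging Face Inference Endpoints conventions for a top-level handler.py, and the model path and image URL below are placeholders.

# Minimal local smoke test for the handler above (a sketch, not part of
# this commit). Assumes the Inference Endpoints convention that __call__
# receives the request body under an "inputs" key.
from handler import EndpointHandler

# Hypothetical path: point this at the BLIP-2 repo the handler ships with.
handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b")

payload = {
    "inputs": {
        "image_url": "https://example.com/cat.jpg",  # placeholder URL
        "prompt": "Question: what is in the picture? Answer:",  # optional
    }
}

result = handler(payload)
print(result["image_caption"])  # unconditional caption
print(result["image_prompt"])   # prompt-conditioned answer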