rroset commited on
Commit
50ba760
verified
1 Parent(s): c0a7220

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -8
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import CLIPProcessor, CLIPModel, pipeline
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
@@ -7,12 +7,8 @@ from typing import Dict, List, Any
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
- # Utilitza el nom correcte del model per carregar-lo des de Hugging Face Hub
11
- self.model = CLIPModel.from_pretrained("hf-hub:rroset/CLIP-ViT-B-32-laion2B-s34B-b79K")
12
- self.processor = CLIPProcessor.from_pretrained("hf-hub:rroset/CLIP-ViT-B-32-laion2B-s34B-b79K")
13
-
14
- # Crea la pipeline de classificaci贸 d'imatges zero-shot
15
- self.classifier = pipeline("zero-shot-image-classification", model=self.model, processor=self.processor)
16
 
17
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
18
  image_input = data.get("inputs", None)
@@ -21,7 +17,7 @@ class EndpointHandler():
21
  if image_input is None or candidate_labels is None:
22
  raise ValueError("Image input or candidate labels not provided")
23
 
24
- # Determina si l'input 茅s una URL o una cadena base64
25
  if image_input.startswith("http"):
26
  response = requests.get(image_input)
27
  image = Image.open(BytesIO(response.content))
 
1
+ from transformers import pipeline
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
+ # Crea la pipeline de classificaci贸 d'imatges zero-shot amb el model espec铆fic
11
+ self.classifier = pipeline("zero-shot-image-classification", model="rroset/CLIP-ViT-B-32-laion2B-s34B-b79K")
 
 
 
 
12
 
13
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
  image_input = data.get("inputs", None)
 
17
  if image_input is None or candidate_labels is None:
18
  raise ValueError("Image input or candidate labels not provided")
19
 
20
+ # Carregar la imatge, podria ser via URL o base64
21
  if image_input.startswith("http"):
22
  response = requests.get(image_input)
23
  image = Image.open(BytesIO(response.content))