VascoBartolo committed on
Commit
776aee4
·
1 Parent(s): 4d14f79

add custom handler

Browse files
__pycache__/handler.cpython-39.pyc ADDED
Binary file (1.73 kB). View file
 
handler.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
4
+
5
class EndpointHandler:
    """Custom inference handler wrapping a LLaVA-NeXT-Video model.

    The handler loads the model and its processor once at construction time
    and, on each call, turns a ``{"clip": ..., "prompt": ...}`` payload into
    the assistant's generated answer.
    """

    # Marker that separates the echoed prompt from the model's answer in the
    # decoded output.
    ANSWER_MARKER = "ASSISTANT:"

    # Sampling configuration passed to ``generate`` on every request.
    GENERATE_KWARGS = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}

    def __init__(self, path=""):
        """Load the model and processor stored under ``path``.

        Args:
            path: Location handed to ``from_pretrained`` for both the model
                and the processor.
        """
        self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            path
        )
        self.processor = LlavaNextVideoProcessor.from_pretrained(path)

        # Inference only: switch the model to evaluation mode.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an answer for a video clip and a text prompt.

        Args:
            data: Request payload; must contain the keys ``"clip"`` and
                ``"prompt"``.

        Returns:
            A single-element list. On success the element is
            ``{"generated_text": <answer>}``; if ``"clip"`` or ``"prompt"``
            is missing it is ``{"error": <message>}``.
        """
        clip = data.get("clip")
        prompt = data.get("prompt")

        if clip is None or prompt is None:
            return [{"error": "Missing 'clip' or 'prompt' in input data"}]

        # Tokenize / preprocess and move the tensors to the model's device.
        inputs_video = self.processor(
            text=prompt, videos=clip, padding=True, return_tensors="pt"
        ).to(self.model.device)

        output = self.model.generate(**inputs_video, **self.GENERATE_KWARGS)
        generated_text = self.processor.batch_decode(
            output, skip_special_tokens=True
        )

        return [{"generated_text": self._extract_answer(generated_text)}]

    def _extract_answer(self, decoded: List[str]) -> str:
        """Return the text following the answer marker in the first output.

        If the marker is absent the whole decoded text is returned; a plain
        ``find(...) + len(...)`` would turn the ``-1`` "not found" result
        into a bogus offset and silently chop off the first characters.
        """
        if not decoded:
            return ""
        full_text = decoded[0]
        _, marker, answer = full_text.partition(self.ANSWER_MARKER)
        return (answer if marker else full_text).strip()
model-00002-of-00003.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d33895934eef0fd0121b24eb475248232d3d2a9e847e7c957814d9cb4ca4e8e4
3
- size 4957878440
 
 
 
 
model-00003-of-00003.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:06fdf4f4afa75b24b6467da6c4bc10a4d609e0aae8ea648a24b43eff0cba5742
3
- size 4176137408