EnariGmbH commited on
Commit
20c305b
·
verified ·
1 Parent(s): 3c1dae0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +43 -10
handler.py CHANGED
@@ -5,8 +5,8 @@ from peft import PeftModel
5
 
6
  class EndpointHandler:
7
  def __init__(self):
8
- self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
9
- self.adapter_model_name = "EnariGmbH/surftown-1.0"
10
 
11
  # Load the base model
12
  self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
@@ -21,6 +21,9 @@ class EndpointHandler:
21
  # Merge the adapter weights into the base model and unload the adapter
22
  self.model = self.model.merge_and_unload()
23
 
 
 
 
24
  # # Optionally, load and save the processor (if needed)
25
  self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
26
 
@@ -30,28 +33,58 @@ class EndpointHandler:
30
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
31
  """
32
  Args:
33
- data (Dict): Contains the input data including "clip" and "prompt".
34
 
35
  Returns:
36
  List[Dict[str, Any]]: The generated text from the model.
37
  """
38
  # Extract inputs from the data dictionary
39
  clip = data.get("clip")
40
- prompt = data.get("prompt")
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  if clip is None or prompt is None:
43
  return [{"error": "Missing 'clip' or 'prompt' in input data"}]
44
-
45
  # Prepare the inputs for the model
46
  inputs_video = self.processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(self.model.device)
47
-
48
  # Generate output from the model
49
  generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
50
  output = self.model.generate(**inputs_video, **generate_kwargs)
51
  generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
52
-
53
  # Extract the relevant part of the assistant's answer
54
  assistant_answer_start = generated_text[0].find("ASSISTANT:") + len("ASSISTANT:")
55
  assistant_answer = generated_text[0][assistant_answer_start:].strip()
56
-
57
- return [{"generated_text": assistant_answer}]
 
5
 
6
  class EndpointHandler:
7
  def __init__(self):
8
+ self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf" # Replace with the original base model ID
9
+ self.adapter_model_name = "EnariGmbH/surftown-1.0" # Your fine-tuned adapter model ID
10
 
11
  # Load the base model
12
  self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
 
21
  # Merge the adapter weights into the base model and unload the adapter
22
  self.model = self.model.merge_and_unload()
23
 
24
+ # # Save the full model
25
+ # model.save_pretrained("surftown_fine_tuned_prompt_0")
26
+
27
  # # Optionally, load and save the processor (if needed)
28
  self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
29
 
 
33
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
34
  """
35
  Args:
36
+ data (Dict): Contains the input data including "clip"
37
 
38
  Returns:
39
  List[Dict[str, Any]]: The generated text from the model.
40
  """
41
  # Extract inputs from the data dictionary
42
  clip = data.get("clip")
43
+
44
+ prompt = """
45
+ You are a surfing coach specialized on perfecting surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video.
46
+ In your detailed analysis you should always mention: Wave Position and paddling, Pushing Phase, Transition, Reaching Phase and finnaly Balance and Control.
47
+ At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
48
+ Never mention your name in the answer and keep the answers short and direct.
49
+ Your answers should ALWAYS follow this structure:
50
+ Description: \n
51
+ Wave Position and paddling: .\n.
52
+ Pushing Phase: \n.
53
+ Transition: \n.
54
+ Reaching Phase: \n
55
+ Balance and Control: \n\n\n
56
+ Summary: \n
57
+ Suggestions for improvement:\n
58
+ """
59
+
60
+
61
+ # Define a conversation history for surfing pop-up move analysis
62
+ conversation = [
63
+ {
64
+ "role": "user",
65
+ "content": [
66
+ {"type": "text", "text": prompt},
67
+ {"type": "video"},
68
+ ],
69
+ },
70
+ ]
71
+
72
+ # Apply the chat template to create the prompt for the model
73
+ prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
74
+
75
  if clip is None or prompt is None:
76
  return [{"error": "Missing 'clip' or 'prompt' in input data"}]
77
+
78
  # Prepare the inputs for the model
79
  inputs_video = self.processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(self.model.device)
80
+
81
  # Generate output from the model
82
  generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
83
  output = self.model.generate(**inputs_video, **generate_kwargs)
84
  generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
85
+
86
  # Extract the relevant part of the assistant's answer
87
  assistant_answer_start = generated_text[0].find("ASSISTANT:") + len("ASSISTANT:")
88
  assistant_answer = generated_text[0][assistant_answer_start:].strip()
89
+
90
+ return [{"generated_text": assistant_answer}]