from typing import Dict, List, Any

import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
from peft import PeftModel
import base64
import numpy as np


def base64_to_numpy(base64_str, shape):
    """Decode a base64 string into a uint8 numpy array with the given shape."""
    arr_bytes = base64.b64decode(base64_str)
    arr = np.frombuffer(arr_bytes, dtype=np.uint8)
    return arr.reshape(shape)


class EndpointHandler:
    def __init__(self, model_dir: str = None):
        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
        self.adapter_model_name = "EnariGmbH/surftown-1.0"

        # Load the base model
        print("Loading base model:", self.base_model_name)
        self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("Base model successfully loaded.")

        # Load the adapter model into the base model
        print("Loading adapter model:", self.adapter_model_name)
        try:
            self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
            print("Adapter model successfully loaded.")
        except Exception as e:
            print(f"Failed to load adapter model: {e}")
            raise e

        # Merge the adapter weights into the base model
        self.model = self.model.merge_and_unload()
        print("Adapter model merged and unloaded.")

        # Load processor
        self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
        print("Processor loaded.")

        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict): Input payload containing the base64-encoded "clip" and its "clip_shape".

        Returns:
            List[Dict[str, Any]]: The generated text from the model.
        """
        # Extract inputs from the data dictionary
        clip_base64 = data.get("clip")
        clip_shape = data.get("clip_shape")  # Expect the shape to be passed
        if clip_base64 is None or clip_shape is None:
            return [{"error": "Missing 'clip' or 'clip_shape' in input data"}]

        # Decode the base64 string back to a numpy array and reshape it
        clip = base64_to_numpy(clip_base64, tuple(clip_shape))

        prompt = """
You are a surfing coach specialized in perfecting the surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video.
In your detailed analysis you should always mention: Wave Position and Paddling, Pushing Phase, Transition, Reaching Phase, and finally Balance and Control.
At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
Your answers should ALWAYS follow this structure:

Description:
Wave Position and Paddling: text
Pushing Phase: text
Transition: text
Reaching Phase: text
Balance and Control: text

Summary:
Suggestions for improvement: text

NEVER MENTION ANY INFORMATION THAT IS NOT RELEVANT FOR THE SURFER. KEEP YOUR ANSWERS SHORT AND DIRECT AND DO NOT MENTION ANY INFORMATION OUTSIDE OF THE BEFORE MENTIONED STRUCTURE.
IMPORTANT: In the Balance and Control section you should also explain how the surfer performs their twists and turns after the pop-up is done.
""" # Define a conversation history for surfing pop-up move analysis conversation = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "video"}, ], }, ] # Apply the chat template to create the prompt for the model prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) if clip is None or prompt is None: return [{"error": "Missing 'clip' or 'prompt' in input data"}] # Ensure clip_bytes is converted properly to the expected format by the model inputs_video = ml.processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(ml.model.device) # Debug: Print the entire inputs_video structure print(f"Keys in inputs_video: {inputs_video.keys()}") # Rename pixel_values_videos to pixel_values if it exists if 'pixel_values_videos' in inputs_video: inputs_video['pixel_values'] = inputs_video.pop('pixel_values_videos') print(f"Renamed pixel_values_videos to pixel_values. New shape: {inputs_video['pixel_values'].shape}") else: print("pixel_values_videos not found in inputs_video") if 'input_ids' in inputs_video: print(f"input_ids shape: {inputs_video['input_ids'].shape}") else: print("input_ids not found in inputs_video") if 'attention_mask' in inputs_video: print(f"attention_mask shape: {inputs_video['attention_mask'].shape}") else: print("attention_mask not found in inputs_video") # Generate output from the model generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9} output = self.model.generate(**inputs_video, **generate_kwargs) generated_text = self.processor.batch_decode(output, skip_special_tokens=True) # Extract the relevant part of the assistant's answer assistant_answer_start = generated_text[0].find("ASSISTANT:") + len("ASSISTANT:") assistant_answer = generated_text[0][assistant_answer_start:].strip() print("model answer", assistant_answer) return [{"generated_text": assistant_answer}]