EnariGmbH commited on
Commit
2554896
·
verified ·
1 Parent(s): 16e7684

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -13
handler.py CHANGED
@@ -11,32 +11,39 @@ def base64_to_numpy(base64_str, shape):
11
  return arr.reshape(shape)
12
 
13
  class EndpointHandler:
14
- def __init__(self):
15
- self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf" # Replace with the original base model ID
16
- self.adapter_model_name = "EnariGmbH/surftown-1.0" # Your fine-tuned adapter model ID
17
 
18
  # Load the base model
 
19
  self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
20
  self.base_model_name,
21
  torch_dtype=torch.float16,
22
  device_map="auto"
23
  )
24
-
25
- # Load the fine-tuned adapter model into the base model
26
- self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
27
-
28
- # Merge the adapter weights into the base model and unload the adapter
 
 
 
 
 
 
 
29
  self.model = self.model.merge_and_unload()
 
30
 
31
- # # Save the full model
32
- # model.save_pretrained("surftown_fine_tuned_prompt_0")
33
-
34
- # # Optionally, load and save the processor (if needed)
35
  self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
 
36
 
37
- # Ensure the model is in evaluation mode
38
  self.model.eval()
39
 
 
40
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
41
  """
42
  Args:
 
11
  return arr.reshape(shape)
12
 
13
  class EndpointHandler:
14
+ def __init__(self, model_dir: str = None):
15
+ self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
16
+ self.adapter_model_name = "EnariGmbH/surftown-1.0"
17
 
18
  # Load the base model
19
+ print("Loading base model:", self.base_model_name)
20
  self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
21
  self.base_model_name,
22
  torch_dtype=torch.float16,
23
  device_map="auto"
24
  )
25
+ print("Base model successfully loaded.")
26
+
27
+ # Load the adapter model into the base model
28
+ print("Loading adapter model:", self.adapter_model_name)
29
+ try:
30
+ self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
31
+ print("Adapter model successfully loaded.")
32
+ except Exception as e:
33
+ print(f"Failed to load adapter model: {e}")
34
+ raise e
35
+
36
+ # Merge the adapter weights into the base model
37
  self.model = self.model.merge_and_unload()
38
+ print("Adapter model merged and unloaded.")
39
 
40
+ # Load processor
 
 
 
41
  self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
42
+ print("Processor loaded.")
43
 
 
44
  self.model.eval()
45
 
46
+
47
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
48
  """
49
  Args: