Update handler.py

handler.py  CHANGED  (+20 -13)
@@ -11,32 +11,39 @@ def base64_to_numpy(base64_str, shape):
     return arr.reshape(shape)
 
 class EndpointHandler:
-    def __init__(self):
-        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
-        self.adapter_model_name = "EnariGmbH/surftown-1.0"
+    def __init__(self, model_dir: str = None):
+        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
+        self.adapter_model_name = "EnariGmbH/surftown-1.0"
 
         # Load the base model
+        print("Loading base model:", self.base_model_name)
         self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
             self.base_model_name,
             torch_dtype=torch.float16,
             device_map="auto"
         )
-
-
-
-
-
+        print("Base model successfully loaded.")
+
+        # Load the adapter model into the base model
+        print("Loading adapter model:", self.adapter_model_name)
+        try:
+            self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
+            print("Adapter model successfully loaded.")
+        except Exception as e:
+            print(f"Failed to load adapter model: {e}")
+            raise e
+
+        # Merge the adapter weights into the base model
         self.model = self.model.merge_and_unload()
+        print("Adapter model merged and unloaded.")
 
-        #
-        # model.save_pretrained("surftown_fine_tuned_prompt_0")
-
-        # # Optionally, load and save the processor (if needed)
+        # Load processor
         self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
+        print("Processor loaded.")
 
-        # Ensure the model is in evaluation mode
         self.model.eval()
 
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Args:
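The new __init__ calls PeftModel.from_pretrained, but this hunk does not show the corresponding import, so "from peft import PeftModel" is assumed to already sit at the top of handler.py alongside the torch and transformers imports. Below is a minimal, self-contained sketch of the same load-then-merge flow, assuming EnariGmbH/surftown-1.0 is a PEFT (LoRA) adapter repository; it illustrates the pattern rather than reproducing the rest of the handler.

import torch
from peft import PeftModel
from transformers import (
    LlavaNextVideoForConditionalGeneration,
    LlavaNextVideoProcessor,
)

base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
adapter_model_name = "EnariGmbH/surftown-1.0"  # assumed to be a LoRA adapter repo

# Load the fp16 base model, sharding it across the available devices.
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the adapter, then fold its weights into the base model so later
# inference runs on a plain transformers model without the peft wrapper.
model = PeftModel.from_pretrained(model, adapter_model_name)
model = model.merge_and_unload()

# As in the handler, the processor comes from the adapter repo, not the base repo.
processor = LlavaNextVideoProcessor.from_pretrained(adapter_model_name)
model.eval()

Merging up front trades flexibility for simplicity: the adapter can no longer be swapped or detached per request, but every subsequent forward pass avoids the PEFT indirection and the serving code only has to manage a single merged model.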