# qwen2-vl-7b-infer / handler.py
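"""Custom handler for Hugging Face Inference Endpoints serving Qwen2-VL-7B.

Accepts a JSON payload of the form {"image_url": ..., "text": ...}, fetches
the image, runs vision-language generation, and returns {"prediction": ...}.
"""
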
from typing import Dict, Any
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import requests
from io import BytesIO
# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path: str = "morthens/qwen2-vl-7b-infer"):
        # Load the processor and model
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype="auto",
            device_map="auto",
        )
        # device_map="auto" already places the weights on the available
        # devices, so a manual model.to(device) is unnecessary here and can
        # fail for models dispatched by accelerate.

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract the input data
        image_url = data.get("image_url", "")
        text = data.get("text", "")

        # Load the image from the URL
        try:
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to fetch or process image: {str(e)}"}

        # Qwen2-VL expects image placeholder tokens in the prompt, so build
        # the prompt through the processor's chat template rather than
        # passing the raw text straight in.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Preprocess the input
        inputs = self.processor(
            text=[prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the same device as the model
        inputs = inputs.to(device)

        # Perform inference
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=128,
        )

        # Drop the prompt tokens so only the newly generated text is decoded
        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]

        # Decode the output
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]

        # Return the raw prediction
        return {"prediction": output_text}
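

# Minimal local smoke test: a sketch only. The image URL below is a
# placeholder, and running this downloads the full 7B checkpoint, which
# requires a large GPU.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "image_url": "https://example.com/sample.jpg",  # hypothetical URL
            "text": "Describe this image.",
        }
    )
    print(result)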