Molmoe / app.py
shambhuDATA's picture
Update app.py
afebc2f verified
raw
history blame
3.24 kB
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
# Load the Molmo processor and model once, at import time, so every call to
# process_image_and_text reuses the same instances.
_MOLMO_CHECKPOINT = 'allenai/Molmo-1B-D-0924'
# Both loaders receive identical options: allow the checkpoint's remote code,
# and let the dtype and device placement be chosen automatically ('auto').
_LOAD_KWARGS = dict(trust_remote_code=True, torch_dtype='auto', device_map='auto')

processor = AutoProcessor.from_pretrained(_MOLMO_CHECKPOINT, **_LOAD_KWARGS)
model = AutoModelForCausalLM.from_pretrained(_MOLMO_CHECKPOINT, **_LOAD_KWARGS)
@spaces.GPU(duration=120)
def process_image_and_text(image, text):
    """Run Molmo on one image plus prompt and return the generated reply.

    Args:
        image: the picture to analyse, either a pixel array accepted by
            ``PIL.Image.fromarray`` or an already-open ``PIL.Image.Image``.
        text: the prompt sent alongside the image.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens removed.
    """
    # Image.fromarray rejects a PIL image, so pass PIL input through unchanged.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Grayscale ("L") or RGBA arrays become three-channel RGB before preprocessing.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    inputs = processor.process(images=[pil_image], text=text)
    # Move every tensor to the model's device and add a leading batch
    # dimension so the generator sees a batch of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )
    # The returned sequence begins with the prompt tokens; keep only what follows.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
# >>> This photograph captures a small black puppy, likely a Labrador or a similar breed,
# sitting attentively on a weathered wooden deck. The deck, composed of three...
import cv2
class Video:
    """Sample frames from a video file and ask the model to point at ``prompt``.

    Attributes:
        prompt: text prompt sent to the model together with every sampled frame.
        output_dir: placeholder for an output location; not read by this class.
    """

    def __init__(self, prompt):
        self.prompt = prompt
        self.output_dir = None

    def read_frame(self, file, interval=1):
        """Read ``file`` and run :meth:`find` on one frame every ``interval`` seconds.

        Args:
            file: path of a video (e.g. an .mp4) that ``cv2.VideoCapture`` can open.
            interval: sampling period in seconds (fps 24 and interval 1 means
                every 24th frame is processed).

        Returns:
            A list of the annotated frames, one for each sampled frame in which
            coordinates were found.

        Raises:
            OSError: if the video cannot be opened.
        """
        video = cv2.VideoCapture(file)
        if not video.isOpened():
            raise OSError(f"Could not open video file: {file!r}")
        annotated = []
        try:
            fps = video.get(cv2.CAP_PROP_FPS)
            # fps * interval is a float (e.g. 29.97 fps) and is 0 when the
            # container reports no frame rate; round it and never go below 1.
            frame_interval = max(1, int(round(fps * interval)))
            frame_index = 0
            while True:
                success, frame = video.read()
                if not success:
                    break
                # Sample by frame position, not by pixel content.
                if frame_index % frame_interval == 0:
                    result = self.find(frame)
                    if result is not None:
                        annotated.append(result)
                frame_index += 1
        finally:
            # Always free the capture handle, even if inference raises.
            video.release()
        return annotated

    def find(self, frame):
        """Ask the model to locate ``self.prompt`` in ``frame`` and mark it.

        Args:
            frame: BGR image array as produced by ``cv2.VideoCapture.read``.

        Returns:
            ``frame`` with a marker drawn at the first "(x, y)" pair found in
            the model's reply, or None when the reply holds no coordinates.
        """
        # OpenCV decodes frames as BGR, while process_image_and_text builds a
        # PIL image from the array, which interprets the channels as RGB.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        text = process_image_and_text(rgb_frame, self.prompt)
        coordinates = extract_coordinates(text)
        if coordinates is None:
            # Nothing to mark in this frame; read_frame moves on to the next one.
            return None
        x, y = coordinates
        # NOTE(review): Molmo's pointing replies may use <point x=".." y="..">
        # tags with percentage coordinates rather than "(x, y)" pixels; confirm
        # against real model output and scale by the frame size if needed.
        return annotate_the_image_with_pointer(int(round(x)), int(round(y)), frame)
import re
# A parenthesised pair of numbers such as "(12, 34.5)". Whitespace around the
# values is allowed; parentheses holding anything else are skipped over.
_NUMBER = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)"
_COORDINATE_PAIR = re.compile(rf"\(\s*({_NUMBER})\s*,\s*({_NUMBER})\s*\)")


def extract_coordinates(text):
    """Return the first "(x, y)" pair found in ``text`` as a tuple of floats.

    Args:
        text: model reply to search; None or an empty string yields None.

    Returns:
        ``(x, y)`` as floats, or None when no parenthesised numeric pair exists.
    """
    if not text:
        return None
    match = _COORDINATE_PAIR.search(text)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))
def annotate_the_image_with_pointer(x, y, frame, radius=2, color=(255, 0, 0), thickness=2):
    """Draw a circular marker centred on pixel ``(x, y)`` of ``frame``.

    Args:
        x: horizontal pixel coordinate; floats are rounded to the nearest int
            because ``cv2.circle`` only accepts integer centres.
        y: vertical pixel coordinate; rounded the same way.
        frame: image array to draw on; ``cv2.circle`` draws into it in place.
        radius: circle radius in pixels.
        color: colour in the frame's channel order ((255, 0, 0) is blue for BGR).
        thickness: outline thickness in pixels.

    Returns:
        The image returned by ``cv2.circle`` (``frame`` with the marker drawn).
    """
    center = (int(round(x)), int(round(y)))
    return cv2.circle(frame, center, radius, color, thickness)
# Read an .mp4 file.
# Sample its frames at an interval of N seconds.