stoefln commited on
Commit
94c76ab
1 Parent(s): 196953e

Implement Molmo-7B, WIP

Browse files
Files changed (1) hide show
  1. app.py +76 -3
app.py CHANGED
@@ -1,7 +1,80 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
3
+ from PIL import Image
4
+ import torch
5
+ import spaces
6
 
7
+ # Load the processor and model
8
+ processor = AutoProcessor.from_pretrained(
9
+ 'allenai/Molmo-7B-D-0924',
10
+ trust_remote_code=True,
11
+ torch_dtype='auto',
12
+ device_map='auto'
13
+ )
14
+
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ 'allenai/Molmo-7B-D-0924',
17
+ trust_remote_code=True,
18
+ torch_dtype='auto',
19
+ device_map='auto'
20
+ )
21
+
22
+
23
+ @spaces.GPU(duration=120)
24
+ def process_image_and_text(image, text):
25
+ # Process the image and text
26
+ inputs = processor.process(
27
+ images=[Image.fromarray(image)],
28
+ text=text
29
+ )
30
+
31
+ # Move inputs to the correct device and make a batch of size 1
32
+ inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
33
+
34
+ # Generate output
35
+ output = model.generate_from_batch(
36
+ inputs,
37
+ GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
38
+ tokenizer=processor.tokenizer
39
+ )
40
+
41
+ # Only get generated tokens; decode them to text
42
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
43
+ generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
44
+
45
+ return generated_text
46
+
47
+ def chatbot(image, text, history):
48
+ if image is None:
49
+ return history + [("Please upload an image first.", None)]
50
+
51
+ response = process_image_and_text(image, text)
52
+ history.append((text, response))
53
+ return history
54
+
55
+ # Define the Gradio interface
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
58
+
59
+ with gr.Row():
60
+ image_input = gr.Image(type="numpy")
61
+ chatbot_output = gr.Chatbot()
62
+
63
+ text_input = gr.Textbox(placeholder="Ask a question about the image...")
64
+ submit_button = gr.Button("Submit")
65
+
66
+ state = gr.State([])
67
+
68
+ submit_button.click(
69
+ chatbot,
70
+ inputs=[image_input, text_input, state],
71
+ outputs=[chatbot_output]
72
+ )
73
+
74
+ text_input.submit(
75
+ chatbot,
76
+ inputs=[image_input, text_input, state],
77
+ outputs=[chatbot_output]
78
+ )
79
 
 
80
  demo.launch()