KasKniesmeijer committed on
Commit
cab1df1
1 Parent(s): 7c0f537

Add SmolVLM with WebGPU frontend

Browse files
Files changed (5) hide show
  1. app.py +35 -0
  2. index.html +11 -2
  3. requirements.txt +3 -0
  4. src/main.js +35 -4
  5. style.css +29 -6
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq
4
+
# Local import: only needed to probe for the optional flash-attn package below.
import importlib.util

# Set the device (CPU or CUDA)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint shared by the processor and model.
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"

# FlashAttention-2 needs both a CUDA device and the optional `flash_attn`
# package. The package is not a declared dependency of this app, so probe for
# it and fall back to the eager implementation instead of failing at load time.
_HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
ATTN_IMPLEMENTATION = (
    "flash_attention_2" if DEVICE == "cuda" and _HAS_FLASH_ATTN else "eager"
)

# Initialize processor and model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    _attn_implementation=ATTN_IMPLEMENTATION,
).to(DEVICE)
15
+
16
+
# Define the function to answer questions
def answer_question(image, question):
    """Answer a free-form question about an image with the loaded model.

    Args:
        image: The image to inspect, as delivered by the UI's image input.
            ``None`` means nothing was uploaded.
        question: The user's question text. ``None``, empty, or
            whitespace-only input is treated as missing.

    Returns:
        The model's answer as a string, or a short prompt asking the user to
        supply whichever input is missing.
    """
    if image is None:
        return "Please upload an image first."

    question = (question or "").strip()
    if not question:
        return "Please enter a question about the image."

    # Build the prompt through the chat template so it contains the image
    # placeholder and the assistant generation prefix.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=[image], text=prompt, return_tensors="pt").to(DEVICE)

    # Without an explicit budget, generate() falls back to a very short default
    # length and truncates the answer.
    max_new_tokens = 500
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + completion; keep only the newly generated
    # tokens so the prompt is not echoed back to the user.
    prompt_length = inputs["input_ids"].shape[1]
    completion = outputs[:, prompt_length:]
    answer = processor.batch_decode(completion, skip_special_tokens=True)[0]
    return answer.strip()
23
+
24
+
# Gradio interface: an image input and a text input feed answer_question,
# whose return value is shown as text.
interface = gr.Interface(
    answer_question,
    ["image", "text"],
    "text",
    title="SmolVLM - Vision-Language Question Answering",
    description="Upload an image and ask a question to get an answer powered by SmolVLM.",
)

# Only start the web server when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()
index.html CHANGED
@@ -4,12 +4,21 @@
4
  <head>
5
  <meta charset="UTF-8">
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
- <title>WebGPU Hugging Face Space</title>
8
  <link rel="stylesheet" href="styles.css">
9
  </head>
10
 
11
  <body>
12
- <canvas id="webgpu-canvas"></canvas>
 
 
 
 
 
 
 
 
 
13
  <script type="module" src="./src/main.js"></script>
14
  </body>
15
 
 
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>SmolVLM WebGPU</title>
  <!-- The stylesheet shipped with this page is style.css (singular). -->
  <link rel="stylesheet" href="style.css">
</head>
10
 
<body>
  <h1>SmolVLM - Vision-Language Question Answering</h1>
  <div id="app">
    <canvas id="webgpu-canvas"></canvas>
    <!-- The inputs have no visible <label>, so aria-label gives them an accessible name. -->
    <div id="controls">
      <input type="file" id="image-upload" accept="image/*" aria-label="Image to ask about">
      <input type="text" id="question" placeholder="Ask a question about the image" aria-label="Question about the image">
      <!-- Explicit type so the button never acts as a form submit if wrapped in a form later. -->
      <button id="submit-btn" type="button">Submit</button>
    </div>
    <!-- aria-live lets assistive technology announce the answer when the script updates it. -->
    <div id="answer" aria-live="polite">Answer will appear here</div>
  </div>
  <script type="module" src="./src/main.js"></script>
</body>
24
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
src/main.js CHANGED
@@ -1,10 +1,11 @@
1
- async function initWebGPU() {
 
 
2
  if (!navigator.gpu) {
3
  document.body.innerHTML = "<p>Your browser does not support WebGPU.</p>";
4
  return;
5
  }
6
 
7
- const canvas = document.getElementById("webgpu-canvas");
8
  const adapter = await navigator.gpu.requestAdapter();
9
  const device = await adapter.requestDevice();
10
  const context = canvas.getContext("webgpu");
@@ -15,7 +16,37 @@ async function initWebGPU() {
15
  alphaMode: "opaque",
16
  });
17
 
18
- console.log("WebGPU initialized successfully!");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
20
 
21
- initWebGPU();
 
 
 
 
 
 
 
 
 
 
 
1
+ async function initializeWebGPU() {
2
+ const canvas = document.getElementById("webgpu-canvas");
3
+
4
  if (!navigator.gpu) {
5
  document.body.innerHTML = "<p>Your browser does not support WebGPU.</p>";
6
  return;
7
  }
8
 
 
9
  const adapter = await navigator.gpu.requestAdapter();
10
  const device = await adapter.requestDevice();
11
  const context = canvas.getContext("webgpu");
 
16
  alphaMode: "opaque",
17
  });
18
 
19
+ console.log("WebGPU initialized.");
20
+ }
21
+
/**
 * Submit the image and question to the backend and return the answer text.
 *
 * Never rejects: every failure (missing file, network error, non-2xx status,
 * unparsable or unexpected JSON) is logged and reported as an "Error: ..."
 * string so callers can display the result directly.
 *
 * NOTE(review): the endpoint path ("/predict") and the multipart payload are
 * kept exactly as before; confirm they match what the backend actually
 * exposes for the Gradio version in use.
 *
 * @param {File} imageFile - The image selected by the user.
 * @param {string} question - The question to ask about the image.
 * @returns {Promise<string>} The first entry of the response's `data` array,
 *   or an error message.
 */
async function submitQuestion(imageFile, question) {
  const fetchError = "Error: Unable to fetch the answer.";

  // FormData would serialise a missing file as the literal string "undefined".
  if (!imageFile) {
    return "Error: Please select an image first.";
  }

  const formData = new FormData();
  formData.append("image", imageFile);
  formData.append("text", question);

  try {
    const response = await fetch("/predict", {
      method: "POST",
      body: formData,
    });

    if (!response.ok) {
      console.error("Failed to get a response:", response.statusText);
      return fetchError;
    }

    const result = await response.json();

    // Guard against a body that lacks the expected { data: [...] } shape.
    if (!result || !Array.isArray(result.data) || result.data.length === 0) {
      console.error("Unexpected response shape:", result);
      return fetchError;
    }

    return result.data[0];
  } catch (err) {
    // fetch() rejects on network failure; response.json() rejects on bad JSON.
    console.error("Request failed:", err);
    return fetchError;
  }
}
41
 
// Handle user interactions: validate the inputs, send them to the backend,
// and show the answer (or an error) in the #answer element.
document.getElementById("submit-btn").addEventListener("click", async (event) => {
  const submitButton = event.currentTarget;
  const answerElement = document.getElementById("answer");
  const imageFile = document.getElementById("image-upload").files[0];
  const question = document.getElementById("question").value.trim();

  // Do not send a request that is missing either input.
  if (!imageFile || !question) {
    answerElement.innerText = "Please choose an image and enter a question.";
    return;
  }

  // Disable the button while the request is in flight to avoid double submits.
  submitButton.disabled = true;
  try {
    const answer = await submitQuestion(imageFile, question);
    answerElement.innerText = `Answer: ${answer}`;
  } catch (err) {
    // Keep the UI responsive even if the request helper rejects.
    console.error("Failed to submit the question:", err);
    answerElement.innerText = "Error: Unable to fetch the answer.";
  } finally {
    submitButton.disabled = false;
  }
});
50
+
51
+ // Initialize WebGPU when the page loads
52
+ initializeWebGPU();
style.css CHANGED
@@ -1,16 +1,39 @@
1
  body {
 
 
 
 
2
  margin: 0;
 
 
 
 
 
 
 
 
3
  display: flex;
4
- justify-content: center;
5
  align-items: center;
6
- height: 100vh;
7
- background: #222;
8
- color: white;
9
- font-family: Arial, sans-serif;
10
  }
11
 
12
  canvas {
13
  width: 800px;
14
  height: 600px;
15
- border: 1px solid #fff;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
 
/* Page: dark theme, centered text, no default spacing. */
body {
  margin: 0;
  padding: 0;
  background: #222;
  color: white;
  font-family: Arial, sans-serif;
  text-align: center;
}

/* Page heading. */
h1 {
  margin: 20px;
}

/* Main container: stacks the canvas, controls and answer vertically. */
#app {
  margin: 20px;
  display: flex;
  flex-direction: column;
  align-items: center;
}

/* Fixed-size WebGPU drawing surface. */
canvas {
  width: 800px;
  height: 600px;
  margin: 20px 0;
  border: 2px solid white;
}

/* File picker, question box and submit button in a single column. */
#controls {
  display: flex;
  flex-direction: column;
  align-items: center;
  gap: 10px;
}

/* Answer text, highlighted in green. */
#answer {
  margin-top: 20px;
  font-size: 1.2em;
  color: #0f0;
}