VictorSanh committed on
Commit
4a9f0a0
·
1 Parent(s): 521b81b

some cleaning and on the path to having token streaming

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -6,7 +6,8 @@ import gradio as gr
6
 
7
  from gradio_client.client import DEFAULT_TEMP_DIR
8
  from playwright.sync_api import sync_playwright
9
- from transformers import AutoProcessor, AutoModelForCausalLM
 
10
  from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
11
  from typing import List
12
  from PIL import Image
@@ -14,15 +15,12 @@ from PIL import Image
14
  from transformers.image_transforms import resize, to_channel_dimension_format
15
 
16
 
17
- API_TOKEN = os.getenv("HF_AUTH_TOKEN")
18
  DEVICE = torch.device("cuda")
19
  PROCESSOR = AutoProcessor.from_pretrained(
20
  "HuggingFaceM4/VLM_WebSight_finetuned",
21
- token=API_TOKEN,
22
  )
23
  MODEL = AutoModelForCausalLM.from_pretrained(
24
  "HuggingFaceM4/VLM_WebSight_finetuned",
25
- token=API_TOKEN,
26
  trust_remote_code=True,
27
  torch_dtype=torch.bfloat16,
28
  ).to(DEVICE)
@@ -134,20 +132,35 @@ def model_inference(
134
  k: v.to(DEVICE)
135
  for k, v in inputs.items()
136
  }
137
- generated_ids = MODEL.generate(
138
- **inputs,
 
 
 
 
 
 
 
 
139
  bad_words_ids=BAD_WORDS_IDS,
140
- max_length=4096
 
141
  )
142
- generated_text = PROCESSOR.batch_decode(
143
- generated_ids,
144
- skip_special_tokens=True
145
- )[0]
 
 
 
 
 
 
 
146
 
147
  rendered_page = render_webpage(generated_text)
148
  return generated_text, rendered_page
149
 
150
-
151
  generated_html = gr.Code(
152
  label="Extracted HTML",
153
  elem_id="generated_html",
@@ -189,7 +202,7 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
189
  regenerate_btn = gr.Button(
190
  value="🔄 Regenerate", visible=True, min_width=120
191
  )
192
- with gr.Column(scale=4) as result_area:
193
  rendered_html.render()
194
 
195
  with gr.Row():
 
6
 
7
  from gradio_client.client import DEFAULT_TEMP_DIR
8
  from playwright.sync_api import sync_playwright
9
+ from threading import Thread
10
+ from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
11
  from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
12
  from typing import List
13
  from PIL import Image
 
15
  from transformers.image_transforms import resize, to_channel_dimension_format
16
 
17
 
 
18
  DEVICE = torch.device("cuda")
19
  PROCESSOR = AutoProcessor.from_pretrained(
20
  "HuggingFaceM4/VLM_WebSight_finetuned",
 
21
  )
22
  MODEL = AutoModelForCausalLM.from_pretrained(
23
  "HuggingFaceM4/VLM_WebSight_finetuned",
 
24
  trust_remote_code=True,
25
  torch_dtype=torch.bfloat16,
26
  ).to(DEVICE)
 
132
  k: v.to(DEVICE)
133
  for k, v in inputs.items()
134
  }
135
+
136
+ streamer = TextIteratorStreamer(
137
+ PROCESSOR.tokenizer,
138
+ decode_kwargs=dict(
139
+ skip_special_tokens=True
140
+ ),
141
+ skip_prompt=True,
142
+ )
143
+ generation_kwargs = dict(
144
+ inputs,
145
  bad_words_ids=BAD_WORDS_IDS,
146
+ max_length=4096,
147
+ streamer=streamer,
148
  )
149
+ thread = Thread(
150
+ target=MODEL.generate,
151
+ kwargs=generation_kwargs,
152
+ )
153
+ thread.start()
154
+ generated_text = ""
155
+ for new_text in streamer:
156
+ generated_text += new_text
157
+ print("before yield")
158
+ # yield generated_text, image
159
+ print("after yield")
160
 
161
  rendered_page = render_webpage(generated_text)
162
  return generated_text, rendered_page
163
 
 
164
  generated_html = gr.Code(
165
  label="Extracted HTML",
166
  elem_id="generated_html",
 
202
  regenerate_btn = gr.Button(
203
  value="🔄 Regenerate", visible=True, min_width=120
204
  )
205
+ with gr.Column(scale=4):
206
  rendered_html.render()
207
 
208
  with gr.Row():