Divyansh12 commited on
Commit
f604f09
·
verified ·
1 Parent(s): 77415fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -63
app.py CHANGED
@@ -7,13 +7,17 @@ import uuid
7
  import time
8
  from pathlib import Path
9
 
10
- # Force the use of CPU
11
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
12
-
13
- # Load tokenizer and model on CPU
14
- tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
15
- model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
16
- model.eval()
 
 
 
 
17
 
18
  # Define folders for uploads and results
19
  UPLOAD_FOLDER = "./uploads"
@@ -23,44 +27,18 @@ for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
23
  if not os.path.exists(folder):
24
  os.makedirs(folder)
25
 
26
- # Function to run the GOT model
27
- def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
28
  unique_id = str(uuid.uuid4())
29
  image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
30
- result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
31
 
32
  image.save(image_path)
33
 
34
  try:
35
- if got_mode == "plain texts OCR":
36
- res = model.chat(tokenizer, image_path, ocr_type='ocr')
37
- return res, None
38
- elif got_mode == "format texts OCR":
39
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
40
- elif got_mode == "plain multi-crop OCR":
41
- res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
42
- return res, None
43
- elif got_mode == "format multi-crop OCR":
44
- res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
45
- elif got_mode == "plain fine-grained OCR":
46
- res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
47
- return res, None
48
- elif got_mode == "format fine-grained OCR":
49
- res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
50
-
51
- res_markdown = res
52
-
53
- if "format" in got_mode and os.path.exists(result_path):
54
- with open(result_path, 'r') as f:
55
- html_content = f.read()
56
- encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
57
- iframe_src = f"data:text/html;base64,{encoded_html}"
58
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
59
- return res_markdown, iframe
60
- else:
61
- return res_markdown, None
62
  except Exception as e:
63
- return f"Error: {str(e)}", None
64
  finally:
65
  if os.path.exists(image_path):
66
  os.remove(image_path)
@@ -81,6 +59,9 @@ uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg
81
  # Create two columns for layout
82
  col1, col2 = st.columns(2)
83
 
 
 
 
84
  if uploaded_image:
85
  image = Image.open(uploaded_image)
86
 
@@ -88,33 +69,13 @@ if uploaded_image:
88
  st.image(image, caption='Uploaded Image', use_column_width=True)
89
 
90
  with col2:
91
- got_mode = st.selectbox("Choose one mode of GOT", [
92
- "plain texts OCR",
93
- "format texts OCR",
94
- "plain multi-crop OCR",
95
- "format multi-crop OCR",
96
- "plain fine-grained OCR",
97
- "format fine-grained OCR",
98
- ])
99
-
100
- fine_grained_mode = None
101
- ocr_color = ""
102
- ocr_box = ""
103
-
104
- if "fine-grained" in got_mode:
105
- fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
106
- if fine_grained_mode == "box":
107
- ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
108
- elif fine_grained_mode == "color":
109
- ocr_color = st.selectbox("Color list", ["red", "green", "blue"])
110
-
111
- if st.button("Submit"):
112
  with st.spinner("Processing..."):
113
- result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
 
 
114
  st.text_area("GOT Output", result_text, height=200)
115
 
116
- if html_result:
117
- st.markdown(html_result, unsafe_allow_html=True)
118
-
119
  # Cleanup old files
120
  cleanup_old_files()
 
 
7
  import time
8
  from pathlib import Path
9
 
10
+ # Define a function to load the model
11
+ def load_model(model_name):
12
+ if model_name == "GOT_CPU":
13
+ tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
14
+ model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
15
+ model = model.eval() # Load model on CPU
16
+ elif model_name == "GOT_GPU":
17
+ tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
18
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
19
+ model = model.eval().cuda() # Load model on GPU
20
+ return tokenizer, model
21
 
22
  # Define folders for uploads and results
23
  UPLOAD_FOLDER = "./uploads"
 
27
  if not os.path.exists(folder):
28
  os.makedirs(folder)
29
 
30
+ # Function to run the GOT model for plain text OCR
31
+ def run_GOT(image, tokenizer, model):
32
  unique_id = str(uuid.uuid4())
33
  image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
 
34
 
35
  image.save(image_path)
36
 
37
  try:
38
+ res = model.chat(tokenizer, image_path, ocr_type='ocr') # Only using plain text OCR
39
+ return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as e:
41
+ return f"Error: {str(e)}"
42
  finally:
43
  if os.path.exists(image_path):
44
  os.remove(image_path)
 
59
  # Create two columns for layout
60
  col1, col2 = st.columns(2)
61
 
62
+ # Model selection
63
+ model_option = st.selectbox("Select Model", ["GOT_CPU", "GOT_GPU"])
64
+
65
  if uploaded_image:
66
  image = Image.open(uploaded_image)
67
 
 
69
  st.image(image, caption='Uploaded Image', use_column_width=True)
70
 
71
  with col2:
72
+ if st.button("Run Plain Text OCR"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with st.spinner("Processing..."):
74
+ # Load the selected model
75
+ tokenizer, model = load_model(model_option)
76
+ result_text = run_GOT(image, tokenizer, model)
77
  st.text_area("GOT Output", result_text, height=200)
78
 
 
 
 
79
  # Cleanup old files
80
  cleanup_old_files()
81
+