Pclanglais committed on
Commit
1fca231
·
verified ·
1 Parent(s): fa86caf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -11,6 +11,7 @@ import shutil
11
  import requests
12
  import pandas as pd
13
  import difflib
 
14
 
15
  # Define the device
16
  # Select the compute device: "cuda" when torch reports CUDA as available, otherwise "cpu".
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -168,24 +169,32 @@ def split_text(text, max_tokens=500):
168
 
169
 
170
  # Function to generate text
171
def ocr_correction(prompt, max_new_tokens=600):
    """Generate a corrected version of OCR text with the module-level model.

    The input is wrapped in the ``### Text ###`` / ``### Correction ###``
    template, passed through ``model.generate`` and decoded. The decoded
    text is printed, then the segment following the first
    ``### Correction ###`` marker is returned.

    Args:
        prompt: The OCR text to correct.
        max_new_tokens: Maximum number of tokens the model may generate.

    Returns:
        The decoded text located after the ``### Correction ###`` marker.

    Raises:
        IndexError: If the decoded output does not contain the marker.
    """
    templated = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    encoded = tokenizer.encode(templated, return_tensors="pt")
    input_ids = encoded.to(device)

    # NOTE(review): no ``do_sample`` is passed here, so ``top_k`` presumably
    # has no effect under the library's default decoding -- confirm.
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "top_k": 50,
    }
    output = model.generate(input_ids, **generation_kwargs)

    # The decoded sequence still contains the prompt template, hence the
    # split on the marker below.
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    print(decoded)
    return decoded.split("### Correction ###")[1]
190
 
191
  # OCR Correction Class
 
11
  import requests
12
  import pandas as pd
13
  import difflib
14
+ from concurrent.futures import ThreadPoolExecutor
15
 
16
  # Define the device
17
  # Select the compute device: "cuda" when torch reports CUDA as available, otherwise "cpu".
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
169
 
170
 
171
  # Function to generate text
172
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Generate a corrected version of OCR text with the module-level model.

    The input is wrapped in the ``### Text ###`` / ``### Correction ###``
    template, sampled from ``model`` (top-k 50, temperature 0.7) and decoded.
    The decoded text is printed, then the segment following the first
    ``### Correction ###`` marker is returned.

    Args:
        prompt: The OCR text to correct.
        max_new_tokens: Maximum number of tokens the model may generate.
        num_threads: Thread count passed to ``torch.set_num_threads``. The
            default is evaluated once, when the function is defined. ``None``
            (which ``os.cpu_count()`` may return) leaves PyTorch's current
            setting untouched.

    Returns:
        The decoded text located after the ``### Correction ###`` marker, or
        the whole decoded text if the marker is missing from the output.
    """
    marker = "### Correction ###"
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # os.cpu_count() can return None, which torch.set_num_threads rejects.
    if num_threads is not None:
        torch.set_num_threads(num_threads)

    # Call generate directly: submitting a single task to a thread pool and
    # blocking on its result adds overhead without any parallelism. PyTorch's
    # own thread pool (configured above) is what parallelises the computation.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
    )

    # The decoded sequence still contains the prompt template, hence the
    # split on the marker below.
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    print(result)

    # Guard the index so a missing marker does not raise an opaque IndexError.
    parts = result.split(marker)
    if len(parts) < 2:
        return result
    return parts[1]
199
 
200
  # OCR Correction Class