---
license: apache-2.0
language:
- ko
metrics:
- cer
- wer
pipeline_tag: image-to-text
---

# trOCR-youtube-kor-OCR

A Korean OCR model fine-tuned as a `VisionEncoderDecoderModel`, built from the following encoder and decoder (a minimal construction sketch is shown below):

- encoder: `facebook/deit-base-distilled-patch16-384`
- decoder: `klue/roberta-base`
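
For reference, here is a hedged sketch of how such an encoder-decoder pair is typically assembled with the `transformers` API before fine-tuning. The token-id wiring below follows the standard TrOCR-style setup and is an assumption for illustration, not code taken from this repository:

```python
from transformers import VisionEncoderDecoderModel, AutoTokenizer

# Combine the two base checkpoints into a single vision encoder-decoder model.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/deit-base-distilled-patch16-384",  # vision encoder
    "klue/roberta-base",                         # Korean language decoder
)

tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")

# Generation-related ids the combined model needs before fine-tuning
# (assumed standard setup for a RoBERTa-style decoder).
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size
```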

## How to Get Started with the Model

```python
from transformers import VisionEncoderDecoderModel, AutoTokenizer, TrOCRProcessor
import torch
from PIL import Image

device = torch.device('cuda')  # use 'cpu' instead if no GPU is available

# the input image can be a .jpg or .png file
# Hugging Face download: https://huggingface.co/gg4ever/trOCR-final
image_path = '(your image path)'
image = Image.open(image_path).convert("RGB")

# the processor handles image preprocessing; the fine-tuned model and its tokenizer do the recognition
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = "gg4ever/trOCR-youtube-kor-OCR"
model = VisionEncoderDecoderModel.from_pretrained(trocr_model).to(device)
tokenizer = AutoTokenizer.from_pretrained(trocr_model)

pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
generated_ids = model.generate(pixel_values)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
```
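
If you need to transcribe several image crops at once, the processor and `generate` also accept batches. A minimal sketch, reusing `model`, `processor`, `tokenizer`, and `device` from the snippet above; the file names, `max_length`, and `num_beams` values are illustrative placeholders, not settings from this model card:

```python
# Illustrative batch inference over multiple word/line crops.
image_paths = ["crop_001.png", "crop_002.png"]  # placeholder paths
images = [Image.open(p).convert("RGB") for p in image_paths]

pixel_values = processor(images, return_tensors="pt").pixel_values.to(device)
generated_ids = model.generate(pixel_values, max_length=64, num_beams=4)  # assumed settings
texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
for path, text in zip(image_paths, texts):
    print(path, "->", text)
```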

## Training Details

### Training Data

- 100k synthetic word images generated with TextRecognitionDataGenerator (trdg): https://github.com/Belval/TextRecognitionDataGenerator/blob/master/trdg/run.py
- 120k word images from the AI-Hub OCR words dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=81
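
A hedged sketch of how word-image/label pairs like these are typically wrapped for Seq2Seq fine-tuning; the class name, sample format, and `max_target_length` below are illustrative assumptions, not the repository's actual preprocessing code:

```python
import torch
from PIL import Image

class OCRWordDataset(torch.utils.data.Dataset):
    """Pairs a word image with its transcription for seq2seq training (illustrative only)."""

    def __init__(self, samples, processor, tokenizer, max_target_length=32):
        # samples: assumed list of (image_path, text) tuples
        self.samples = samples
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, text = self.samples[idx]
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)
        labels = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=self.max_target_length,
        ).input_ids
        # ignore padding positions in the loss
        labels = [t if t != self.tokenizer.pad_token_id else -100 for t in labels]
        return {"pixel_values": pixel_values, "labels": torch.tensor(labels)}
```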

### Training Hyperparameters

```python
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    fp16=True,
    learning_rate=4e-5,
    output_dir="./models",
    save_steps=2000,
    eval_steps=1000,
    warmup_steps=2000,
    weight_decay=0.01,
)
```
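
These arguments are the kind typically passed to a `Seq2SeqTrainer` together with CER/WER metrics (the metrics listed in this card's front matter). The wiring below is a hedged sketch, not the original training script: `train_dataset` and `eval_dataset` are assumed to be torch datasets yielding `pixel_values` and `labels` (for example the illustrative `OCRWordDataset` above), and `model`, `tokenizer`, and `training_args` are the objects defined earlier:

```python
import evaluate
from transformers import Seq2SeqTrainer, default_data_collator

cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    # replace -100 (ignored positions) with the pad id so labels decode cleanly
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {
        "cer": cer_metric.compute(predictions=pred_str, references=label_str),
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
    }

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,   # assumed dataset object
    eval_dataset=eval_dataset,     # assumed dataset object
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
```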