Gabor Cselle commited on
Commit
99f802a
·
1 Parent(s): 95ccd40

Train a Font Identifier using ResNet18

Browse files
README.md CHANGED
@@ -7,5 +7,6 @@ Follow along:
7
  - [On Threads.net](https://www.threads.net/@gaborcselle/post/CzZJpJCpxTz)
8
  - [On Twitter](https://twitter.com/gabor/status/1722300841691103467)
9
 
10
- Generate sample images (note this will work only on Mac): [gen_sample_data.py]
11
- Arrange test images into test and train: [arrange_train_test_images.py]
 
 
7
  - [On Threads.net](https://www.threads.net/@gaborcselle/post/CzZJpJCpxTz)
8
  - [On Twitter](https://twitter.com/gabor/status/1722300841691103467)
9
 
10
+ Generate sample images (note this will work only on Mac): [gen_sample_data.py](gen_sample_data.py)
11
+ Arrange test images into test and train: [arrange_train_test_images.py](arrange_train_test_images.py)
12
+ Train a ResNet18 on the data: [train_font_identifier.py](train_font_identifier.py)
arrange_train_test_images.py CHANGED
@@ -29,10 +29,10 @@ for font in fonts:
29
  train_files = font_files[:int(0.8 * len(font_files))]
30
  test_files = font_files[int(0.8 * len(font_files)):]
31
 
32
- # Moving training files
33
  for train_file in train_files:
34
  shutil.move(os.path.join(source_dir, train_file), font_train_dir)
35
 
36
- # Moving test files
37
  for test_file in test_files:
38
  shutil.move(os.path.join(source_dir, test_file), font_test_dir)
 
29
  train_files = font_files[:int(0.8 * len(font_files))]
30
  test_files = font_files[int(0.8 * len(font_files)):]
31
 
32
+ # Move training files
33
  for train_file in train_files:
34
  shutil.move(os.path.join(source_dir, train_file), font_train_dir)
35
 
36
+ # Move test files
37
  for test_file in test_files:
38
  shutil.move(os.path.join(source_dir, test_file), font_test_dir)
gen_sample_data.py CHANGED
@@ -7,6 +7,8 @@ import nltk
7
  from nltk.corpus import brown
8
  import random
9
 
 
 
10
  # Download the necessary data from nltk
11
  nltk.download('brown')
12
 
@@ -55,7 +57,7 @@ for font_dir in font_dirs:
55
 
56
  # Counter for the image filename
57
  j = 0
58
- for i in range(10): # Generate 50 images per font - reduced to 10 for now to make things faster
59
  prose_sample = random_prose_text(all_brown_words)
60
 
61
  for text in [prose_sample]:
 
7
  from nltk.corpus import brown
8
  import random
9
 
10
+ IMAGES_PER_FONT = 50
11
+
12
  # Download the necessary data from nltk
13
  nltk.download('brown')
14
 
 
57
 
58
  # Counter for the image filename
59
  j = 0
60
+ for i in range(IMAGES_PER_FONT): # Generate 50 images per font - reduced to 10 for now to make things faster
61
  prose_sample = random_prose_text(all_brown_words)
62
 
63
  for text in [prose_sample]:
requirements.txt CHANGED
@@ -1 +1,6 @@
 
1
  Pillow==9.5.0
 
 
 
 
 
1
+ nltk==3.8.1
2
  Pillow==9.5.0
3
+ torch==2.0.0
4
+ torchaudio==2.0.1
5
+ torchvision==0.15.1
6
+ tqdm==4.65.0
train_font_identifier.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import time
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.optim import lr_scheduler
7
+ from torchvision import datasets, models, transforms
8
+ from tqdm import tqdm
9
+
10
+ # Directory with organized font images
11
+ data_dir = './train_test_images'
12
+
13
+ # Define transformations for the image data
14
+ data_transforms = {
15
+ 'train': transforms.Compose([
16
+ transforms.Resize((224, 224)), # Resize to the input size expected by the model
17
+ transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet standards
19
+ ]),
20
+ 'test': transforms.Compose([
21
+ transforms.Resize((224, 224)), # Resize to the input size expected by the model
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ]),
25
+ }
26
+
27
+
28
+ # Create datasets
29
+ image_datasets = {
30
+ x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
31
+ for x in ['train', 'test']
32
+ }
33
+
34
+ # Create dataloaders
35
+ dataloaders = {
36
+ 'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4),
37
+ 'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=4)
38
+ }
39
+
40
+ # Define the model
41
+ model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
42
+
43
+ # Define the loss function
44
+ criterion = torch.nn.CrossEntropyLoss()
45
+
46
+ # Optimizer (you can replace 'model.parameters()' with specific parameters to optimize if needed)
47
+ optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
48
+
49
+ # Decay LR by a factor of 0.1 every 7 epochs
50
+ exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
51
+
52
+ # Number of epochs to train for
53
+ num_epochs = 25
54
+
55
+ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
56
+ since = time.time()
57
+
58
+ best_model_wts = copy.deepcopy(model.state_dict())
59
+ best_acc = 0.0
60
+
61
+ for epoch in range(num_epochs):
62
+ print('Epoch {}/{}'.format(epoch, num_epochs - 1))
63
+ print('-' * 10)
64
+
65
+ # Each epoch has a training and validation phase
66
+ for phase in ['train', 'test']:
67
+ if phase == 'train':
68
+ model.train() # Set model to training mode
69
+ else:
70
+ model.eval() # Set model to evaluate mode
71
+
72
+ running_loss = 0.0
73
+ running_corrects = 0
74
+
75
+ # Iterate over data.
76
+ # Here we wrap the dataloader with tqdm for a progress bar
77
+ for inputs, labels in tqdm(dataloaders[phase], desc=f"Epoch {epoch} - {phase}"):
78
+ # Zero the parameter gradients
79
+ optimizer.zero_grad()
80
+
81
+ # Forward
82
+ # Track history if only in train
83
+ with torch.set_grad_enabled(phase == 'train'):
84
+ outputs = model(inputs)
85
+ _, preds = torch.max(outputs, 1)
86
+ loss = criterion(outputs, labels)
87
+
88
+ # Backward + optimize only if in training phase
89
+ if phase == 'train':
90
+ loss.backward()
91
+ optimizer.step()
92
+
93
+ # Statistics
94
+ running_loss += loss.item() * inputs.size(0)
95
+ running_corrects += torch.sum(preds == labels.data)
96
+ if phase == 'train':
97
+ scheduler.step()
98
+
99
+ epoch_loss = running_loss / len(image_datasets[phase])
100
+ epoch_acc = running_corrects.double() / len(image_datasets[phase])
101
+
102
+ print('{} Loss: {:.4f} Acc: {:.4f}'.format(
103
+ phase, epoch_loss, epoch_acc))
104
+
105
+ # Deep copy the model
106
+ if phase == 'test' and epoch_acc > best_acc:
107
+ best_acc = epoch_acc
108
+ best_model_wts = copy.deepcopy(model.state_dict())
109
+
110
+ print()
111
+
112
+ time_elapsed = time.time() - since
113
+ print('Training complete in {:.0f}m {:.0f}s'.format(
114
+ time_elapsed // 60, time_elapsed % 60))
115
+ print('Best test Acc: {:4f}'.format(best_acc))
116
+
117
+ # Load best model weights
118
+ model.load_state_dict(best_model_wts)
119
+ return model
120
+
121
+ # Train the model
122
+ model = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)