Gabor Cselle
commited on
Commit
·
2e58968
1
Parent(s):
ea56d2d
It does help if we save the model :-)
Browse files- train_font_identifier.py +4 -1
train_font_identifier.py
CHANGED
@@ -13,7 +13,7 @@ data_dir = './train_test_images'
|
|
13 |
|
14 |
# Transformations for the image data
|
15 |
data_transforms = transforms.Compose([
|
16 |
-
|
17 |
transforms.Resize((224, 224)), # Resize images to the expected input size of the model
|
18 |
transforms.ToTensor(), # Convert images to PyTorch tensors
|
19 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats
|
@@ -88,3 +88,6 @@ for epoch in range(num_epochs):
|
|
88 |
train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
|
89 |
val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
|
90 |
print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
|
|
|
|
|
|
|
|
13 |
|
14 |
# Transformations for the image data
|
15 |
data_transforms = transforms.Compose([
|
16 |
+
transforms.Grayscale(num_output_channels=3), # Convert images to grayscale with 3 channels
|
17 |
transforms.Resize((224, 224)), # Resize images to the expected input size of the model
|
18 |
transforms.ToTensor(), # Convert images to PyTorch tensors
|
19 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats
|
|
|
88 |
train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
|
89 |
val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
|
90 |
print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
|
91 |
+
|
92 |
+
# Save the model to disk
|
93 |
+
torch.save(model.state_dict(), 'font_identifier_model.pth')
|