File size: 2,416 Bytes
e4147e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from typing import Any, Mapping
from .configuration_arabichar import ArabiCharModelConfig
from transformers import PreTrainedModel
import torch
import torch.nn as nn
class ArabiCharModel(nn.Module):
    """CNN backbone for Arabic character image classification.

    Two convolutional stages — each three 5x5 convs followed by 2x2
    max-pooling and batch normalization — feeding a three-layer fully
    connected head with dropout. The forward pass returns class
    probabilities (softmax over ``config.num_classes``).

    NOTE(review): because of the final softmax, downstream losses that
    expect raw logits (e.g. cross-entropy) receive probabilities —
    confirm against the training code whether this is intended.
    """

    def __init__(self, config):
        super().__init__()
        # Stage 1: padding=4 on the first conv pre-compensates the
        # spatial shrink of the two unpadded 5x5 convs that follow.
        self.conv1 = nn.Conv2d(1, config.conv1_channels, kernel_size=5, padding=4)
        self.conv2 = nn.Conv2d(config.conv1_channels, config.conv1_channels, kernel_size=5)
        self.conv3 = nn.Conv2d(config.conv1_channels, config.conv1_channels, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2)
        self.bn1 = nn.BatchNorm2d(config.conv1_channels)
        # Stage 2: identical layout, widened to conv2_channels.
        self.conv4 = nn.Conv2d(config.conv1_channels, config.conv2_channels, kernel_size=5, padding=4)
        self.conv5 = nn.Conv2d(config.conv2_channels, config.conv2_channels, kernel_size=5)
        self.conv6 = nn.Conv2d(config.conv2_channels, config.conv2_channels, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2)
        self.bn2 = nn.BatchNorm2d(config.conv2_channels)
        # Classifier head. The 5*5 factor hard-codes the post-conv spatial
        # size — this holds for 32x32 inputs given the layout above;
        # TODO confirm against the dataset's image size.
        self.fc1 = nn.Linear(config.conv2_channels * 5 * 5, config.fc1_units)
        self.fc2 = nn.Linear(config.fc1_units, config.fc1_units)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.fc3 = nn.Linear(config.fc1_units, config.num_classes)

    def forward(self, x):
        """Map a (batch, 1, H, W) image tensor to class probabilities."""
        out = x
        # Stage 1: three ReLU convs, then pool and normalize.
        for conv in (self.conv1, self.conv2, self.conv3):
            out = torch.relu(conv(out))
        out = self.bn1(self.pool1(out))
        # Stage 2: same pattern at the wider channel count.
        for conv in (self.conv4, self.conv5, self.conv6):
            out = torch.relu(conv(out))
        out = self.bn2(self.pool2(out))
        # Flatten per sample and run the classifier head.
        out = out.view(out.size(0), -1)
        out = torch.relu(self.fc1(out))
        out = self.dropout(torch.relu(self.fc2(out)))
        return torch.softmax(self.fc3(out), dim=1)
class ArabiCharModelForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`ArabiCharModel`.

    Exposes the backbone through the standard ``forward(tensor, labels)``
    interface, returning a dict with ``"logits"`` and, when labels are
    supplied, a ``"loss"`` entry.
    """

    config_class = ArabiCharModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ArabiCharModel(config)

    def forward(self, tensor, labels=None):
        """Run the backbone; additionally compute cross-entropy loss
        when ``labels`` is provided."""
        logits = self.model(tensor)
        if labels is not None:
            # BUG FIX: ``torch.nn.cross_entropy`` does not exist — the
            # original raised AttributeError whenever labels were passed.
            # The functional form lives in torch.nn.functional.
            # NOTE(review): the backbone ends in softmax, so ``logits``
            # here are probabilities; cross_entropy expects raw logits —
            # confirm whether the backbone's softmax should be removed.
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    def load_state_dict(self, model_name):
        # NOTE(review): this shadows nn.Module.load_state_dict with an
        # incompatible signature (a checkpoint path instead of a state
        # dict). Kept as-is since callers depend on it, but consider
        # renaming (e.g. ``load_weights``) in a future API revision.
        self.model.load_state_dict(torch.load(model_name))
|