PoseEstimationYOLOv8 / src /classification_keypoint.py
nishantkaushik20's picture
Update src/classification_keypoint.py
ff8f460
raw
history blame
1.65 kB
import torch
import torch.nn as nn
class NeuralNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features fed to the first layer.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores produced by the last layer.
    """

    def __init__(
        self,
        input_size: int = 24,
        hidden_size: int = 256,
        num_classes: int = 5,
    ):
        super().__init__()
        # The attribute names become state-dict keys (``l1.weight``, ``l2.bias``, ...),
        # so they must stay as-is for saved checkpoints to load.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalized) class scores for ``x``."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)
class KeypointClassification:
    """Predict a pose label from a keypoint feature vector with a trained ``NeuralNet``.

    The checkpoint at ``path_model`` is loaded once, at construction time.
    """

    def __init__(self, path_model):
        """Store the checkpoint path, choose a device and load the model.

        Args:
            path_model: Path to a saved ``NeuralNet`` state dict.
        """
        self.path_model = path_model
        # Entry i is the label reported when output score i is the largest.
        self.classes = ['Downdog', 'Goddess', 'Plank', 'Tree', 'Warrior2']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Build a ``NeuralNet``, load weights from ``self.path_model`` and prepare it for inference."""
        # One output per label keeps the network width in sync with ``self.classes``.
        self.model = NeuralNet(num_classes=len(self.classes))
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
        # (newer PyTorch releases accept ``weights_only=True`` to restrict this).
        state_dict = torch.load(self.path_model, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # load_state_dict copies values into the existing (CPU) parameters, so the
        # model has to be moved explicitly for ``self.device`` to take effect.
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, input_keypoint):
        """Return the predicted label for one keypoint feature vector.

        Args:
            input_keypoint: A tensor or array-like (list, numpy array, ...) of
                features whose last dimension equals the model's input size.

        Returns:
            The entry of ``self.classes`` whose output score is the largest.
        """
        # as_tensor also casts existing tensors (e.g. float64 ones from numpy) to
        # float32 and moves them to the model's device; a matching tensor is reused as-is.
        input_tensor = torch.as_tensor(
            input_keypoint, dtype=torch.float32, device=self.device
        )
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out = self.model(input_tensor)
        _, predict = torch.max(out, -1)
        # int() requires exactly one prediction, i.e. a single (unbatched) sample.
        return self.classes[int(predict)]
if __name__ == '__main__':
    import sys

    # The checkpoint path can be passed as the first CLI argument; without one the
    # original author's location is used.
    default_path_model = '/Users/nishantkaushik20/Me/source-code/AI/PoseEstimationYOLOv8/models/pose_classification.pt'
    path_model = sys.argv[1] if len(sys.argv) > 1 else default_path_model
    keypoint_classification = KeypointClassification(path_model=path_model)
    # The dummy vector must be as wide as the first layer expects (24 by default);
    # reading it from the model avoids a shape-mismatch error.
    dummy_input = torch.randn(keypoint_classification.model.l1.in_features)
    classification = keypoint_classification(dummy_input)
    print(classification)