import sys

import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing unnormalized class scores.

    Args:
        input_size: Number of features per sample (size of the input's last dimension).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size=24, hidden_size=256, num_classes=5):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(*, num_classes)`` for input of shape ``(*, input_size)``."""
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out


class KeypointClassification:
    """Load a trained ``NeuralNet`` checkpoint and map its predictions to pose names.

    Args:
        path_model: Path to a file holding a ``state_dict`` compatible with
            ``NeuralNet()`` built with its default sizes.
    """

    def __init__(self, path_model):
        self.path_model = path_model
        # Index order must match the class indices used when the model was trained.
        self.classes = ['Downdog', 'Goddess', 'Plank', 'Tree', 'Warrior2']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Build the network, load weights from ``self.path_model`` and prepare it for inference."""
        self.model = NeuralNet()
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        self.model.load_state_dict(
            torch.load(self.path_model, map_location=self.device)
        )
        # load_state_dict copies values into the model's existing parameters, so
        # the model has to be moved explicitly for self.device to take effect.
        self.model.to(self.device)
        # Inference mode for any train/eval-dependent layers.
        self.model.eval()

    def __call__(self, input_keypoint):
        """Classify one sample or a batch of samples.

        Args:
            input_keypoint: A tensor or array-like whose last dimension has the
                model's input size (24 by default; presumably flattened keypoint
                coordinates -- TODO confirm). Any dtype is accepted and cast to
                float32.

        Returns:
            A class name (``str``) for a single 1-D sample, or a list (nested for
            higher-rank input) of class names for batched input.
        """
        if isinstance(input_keypoint, torch.Tensor):
            x = input_keypoint.to(device=self.device, dtype=torch.float32)
        else:
            x = torch.tensor(input_keypoint, dtype=torch.float32, device=self.device)
        # No gradients are needed for prediction; avoid building an autograd graph.
        with torch.no_grad():
            out = self.model(x)
        _, predict = torch.max(out, -1)
        # tolist() gives a plain int for a single sample and nested lists for batches.
        return self._to_labels(predict.tolist())

    def _to_labels(self, indices):
        """Map an int, or an arbitrarily nested list of ints, of class indices to class names."""
        if isinstance(indices, int):
            return self.classes[indices]
        return [self._to_labels(i) for i in indices]


if __name__ == '__main__':
    # Optionally pass a checkpoint path as the first CLI argument.
    default_path = '/Users/alimustofa/Me/source-code/AI/YoloV8_Pose_Classification/models/pose_classification.pt'
    keypoint_classification = KeypointClassification(
        path_model=sys.argv[1] if len(sys.argv) > 1 else default_path
    )
    # Size the dummy sample from the model itself so it always matches the input layer.
    dummy_input = torch.randn(keypoint_classification.model.l1.in_features)
    classification = keypoint_classification(dummy_input)
    print(classification)