waidhoferj commited on
Commit
17a2a7d
·
1 Parent(s): ad4c4e2

added evaulations

Browse files
Files changed (2) hide show
  1. models/training_environment.py +55 -9
  2. models/utils.py +29 -2
models/training_environment.py CHANGED
@@ -1,10 +1,16 @@
1
  import importlib
2
- from models.utils import calculate_metrics
3
-
4
  from abc import ABC, abstractmethod
5
  import pytorch_lightning as pl
 
6
  import torch
7
  import torch.nn as nn
 
 
 
 
 
8
 
9
 
10
  class TrainingEnvironment(pl.LightningModule):
@@ -27,8 +33,8 @@ class TrainingEnvironment(pl.LightningModule):
27
  config["training_environment"].get("loggers", {})
28
  )
29
  self.config = config
30
- self.has_multi_label_predictions = (
31
- not type(criterion).__name__ == "CrossEntropyLoss"
32
  )
33
  self.save_hyperparameters(
34
  {
@@ -44,6 +50,8 @@ class TrainingEnvironment(pl.LightningModule):
44
  ) -> torch.Tensor:
45
  features, labels = batch
46
  outputs = self.model(features)
 
 
47
  loss = self.criterion(outputs, labels)
48
  metrics = calculate_metrics(
49
  outputs,
@@ -62,6 +70,8 @@ class TrainingEnvironment(pl.LightningModule):
62
  ):
63
  x, y = batch
64
  preds = self.model(x)
 
 
65
  metrics = calculate_metrics(
66
  preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
67
  )
@@ -71,12 +81,48 @@ class TrainingEnvironment(pl.LightningModule):
71
  def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
72
  x, y = batch
73
  preds = self.model(x)
74
- self.log_dict(
75
- calculate_metrics(
76
- preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
77
- ),
78
- prog_bar=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def configure_optimizers(self):
82
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
 
1
  import importlib
2
+ from models.utils import calculate_metrics, plot_to_image, get_dance_mapping
3
+ import numpy as np
4
  from abc import ABC, abstractmethod
5
  import pytorch_lightning as pl
6
+ import matplotlib.pyplot as plt
7
  import torch
8
  import torch.nn as nn
9
+ from sklearn.metrics import (
10
+ roc_auc_score,
11
+ confusion_matrix,
12
+ ConfusionMatrixDisplay,
13
+ )
14
 
15
 
16
  class TrainingEnvironment(pl.LightningModule):
 
33
  config["training_environment"].get("loggers", {})
34
  )
35
  self.config = config
36
+ self.has_multi_label_predictions = not (
37
+ type(criterion).__name__ == "CrossEntropyLoss"
38
  )
39
  self.save_hyperparameters(
40
  {
 
50
  ) -> torch.Tensor:
51
  features, labels = batch
52
  outputs = self.model(features)
53
+ if self.has_multi_label_predictions:
54
+ outputs = nn.functional.sigmoid(outputs)
55
  loss = self.criterion(outputs, labels)
56
  metrics = calculate_metrics(
57
  outputs,
 
70
  ):
71
  x, y = batch
72
  preds = self.model(x)
73
+ if self.has_multi_label_predictions:
74
+ preds = nn.functional.sigmoid(preds)
75
  metrics = calculate_metrics(
76
  preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
77
  )
 
81
  def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
82
  x, y = batch
83
  preds = self.model(x)
84
+ if self.has_multi_label_predictions:
85
+ preds = nn.functional.sigmoid(preds)
86
+ metrics = calculate_metrics(
87
+ preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
88
+ )
89
+ if not self.has_multi_label_predictions:
90
+ preds = nn.functional.softmax(preds, dim=1)
91
+ y = y.detach().cpu().numpy()
92
+ preds = preds.detach().cpu().numpy()
93
+ # ROC-auc score
94
+ try:
95
+ metrics["test/roc_auc_score"] = torch.tensor(
96
+ roc_auc_score(y, preds), dtype=torch.float32
97
+ )
98
+ except ValueError:
99
+ # If there is only one class, roc_auc_score will throw an error
100
+ pass
101
+
102
+ pass
103
+ self.log_dict(metrics, prog_bar=True)
104
+ # Create confusion matrix
105
+
106
+ preds = preds.argmax(axis=1)
107
+ y = y.argmax(axis=1)
108
+ cm = confusion_matrix(
109
+ preds, y, normalize="all", labels=np.arange(len(self.config["dance_ids"]))
110
  )
111
+ if hasattr(self, "test_cm"):
112
+ self.test_cm += cm
113
+ else:
114
+ self.test_cm = cm
115
+
116
+ def on_test_end(self):
117
+ dance_ids = sorted(self.config["dance_ids"])
118
+ np.fill_diagonal(self.test_cm, 0)
119
+ cm = self.test_cm / self.test_cm.max()
120
+ ConfusionMatrixDisplay(cm, display_labels=dance_ids).plot()
121
+ image = plot_to_image(plt.gcf())
122
+ image = torch.tensor(image, dtype=torch.uint8)
123
+ image = image.permute(2, 0, 1)
124
+ self.logger.experiment.add_image("test/confusion_matrix", image, 0)
125
+ delattr(self, "test_cm")
126
 
127
  def configure_optimizers(self):
128
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
models/utils.py CHANGED
@@ -2,6 +2,11 @@ import torch.nn as nn
2
  import torch
3
  import numpy as np
4
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 
 
 
 
 
5
 
6
 
7
  class LabelWeightedBCELoss(nn.Module):
@@ -38,10 +43,13 @@ def calculate_metrics(
38
  ) -> dict[str, torch.Tensor]:
39
  target = target.detach().cpu().numpy()
40
  pred = pred.detach().cpu()
41
- pred = nn.functional.softmax(pred, dim=1)
 
42
  pred = pred.numpy()
43
  params = {
44
- "y_true": target if multi_label else target.argmax(1),
 
 
45
  "y_pred": np.array(pred > threshold, dtype=float)
46
  if multi_label
47
  else pred.argmax(1),
@@ -85,3 +93,22 @@ def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
85
  def compute_hf_metrics(eval_pred):
86
  predictions = np.argmax(eval_pred.predictions, axis=1)
87
  return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import numpy as np
4
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
5
+ from functools import cache
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+ from PIL import Image
10
 
11
 
12
  class LabelWeightedBCELoss(nn.Module):
 
43
  ) -> dict[str, torch.Tensor]:
44
  target = target.detach().cpu().numpy()
45
  pred = pred.detach().cpu()
46
+ if not multi_label:
47
+ pred = nn.functional.softmax(pred, dim=1)
48
  pred = pred.numpy()
49
  params = {
50
+ "y_true": np.array(target > 0.0, dtype=float)
51
+ if multi_label
52
+ else target.argmax(1),
53
  "y_pred": np.array(pred > threshold, dtype=float)
54
  if multi_label
55
  else pred.argmax(1),
 
93
  def compute_hf_metrics(eval_pred):
94
  predictions = np.argmax(eval_pred.predictions, axis=1)
95
  return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
96
+
97
+
98
+ @cache
99
+ def get_dance_mapping(mapping_file: str) -> dict[str, str]:
100
+ mapping_df = pd.read_csv(mapping_file)
101
+ return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
102
+
103
+
104
+ def plot_to_image(figure) -> np.ndarray:
105
+ """Converts the matplotlib plot specified by 'figure' to a PNG image and
106
+ returns it. The supplied figure is closed and inaccessible after this call."""
107
+ # Save the plot to a PNG in memory.
108
+ buf = io.BytesIO()
109
+ plt.savefig(buf, format="png")
110
+ # Closing the figure prevents it from being displayed directly inside
111
+ # the notebook.
112
+ plt.close(figure)
113
+ buf.seek(0)
114
+ return np.array(Image.open(buf))