dragonSwing
commited on
Commit
·
32335bd
1
Parent(s):
7f0e78a
Fix error prob bug
Browse files- gec_model.py +6 -3
gec_model.py
CHANGED
@@ -78,6 +78,8 @@ class GecBERTModel(torch.nn.Module):
|
|
78 |
List of punctuations.
|
79 |
"""
|
80 |
super().__init__()
|
|
|
|
|
81 |
self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
|
82 |
self.device = (
|
83 |
torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
|
@@ -106,8 +108,6 @@ class GecBERTModel(torch.nn.Module):
|
|
106 |
|
107 |
self.indexers = []
|
108 |
self.models = []
|
109 |
-
if isinstance(model_paths, str):
|
110 |
-
model_paths = [model_paths]
|
111 |
for model_path in model_paths:
|
112 |
model = Seq2LabelsModel.from_pretrained(model_path)
|
113 |
config = model.config
|
@@ -337,7 +337,10 @@ class GecBERTModel(torch.nn.Module):
|
|
337 |
for output, weight in zip(data, self.model_weights):
|
338 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
339 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
340 |
-
|
|
|
|
|
|
|
341 |
|
342 |
max_vals = torch.max(all_class_probs, dim=-1)
|
343 |
probs = max_vals[0].tolist()
|
|
|
78 |
List of punctuations.
|
79 |
"""
|
80 |
super().__init__()
|
81 |
+
if isinstance(model_paths, str):
|
82 |
+
model_paths = [model_paths]
|
83 |
self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
|
84 |
self.device = (
|
85 |
torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
|
|
|
108 |
|
109 |
self.indexers = []
|
110 |
self.models = []
|
|
|
|
|
111 |
for model_path in model_paths:
|
112 |
model = Seq2LabelsModel.from_pretrained(model_path)
|
113 |
config = model.config
|
|
|
337 |
for output, weight in zip(data, self.model_weights):
|
338 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
339 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
340 |
+
class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
|
341 |
+
error_probs_d = class_probabilities_d[:, :, self.incorr_index]
|
342 |
+
incorr_prob = torch.max(error_probs_d, dim=-1)[0]
|
343 |
+
error_probs += weight * incorr_prob / sum(self.model_weights)
|
344 |
|
345 |
max_vals = torch.max(all_class_probs, dim=-1)
|
346 |
probs = max_vals[0].tolist()
|