dragonSwing commited on
Commit
32335bd
·
1 Parent(s): 7f0e78a

Fix error prob bug

Browse files
Files changed (1) hide show
  1. gec_model.py +6 -3
gec_model.py CHANGED
@@ -78,6 +78,8 @@ class GecBERTModel(torch.nn.Module):
78
  List of punctuations.
79
  """
80
  super().__init__()
 
 
81
  self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
82
  self.device = (
83
  torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
@@ -106,8 +108,6 @@ class GecBERTModel(torch.nn.Module):
106
 
107
  self.indexers = []
108
  self.models = []
109
- if isinstance(model_paths, str):
110
- model_paths = [model_paths]
111
  for model_path in model_paths:
112
  model = Seq2LabelsModel.from_pretrained(model_path)
113
  config = model.config
@@ -337,7 +337,10 @@ class GecBERTModel(torch.nn.Module):
337
  for output, weight in zip(data, self.model_weights):
338
  class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
339
  all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
340
- error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
 
 
 
341
 
342
  max_vals = torch.max(all_class_probs, dim=-1)
343
  probs = max_vals[0].tolist()
 
78
  List of punctuations.
79
  """
80
  super().__init__()
81
+ if isinstance(model_paths, str):
82
+ model_paths = [model_paths]
83
  self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
84
  self.device = (
85
  torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
 
108
 
109
  self.indexers = []
110
  self.models = []
 
 
111
  for model_path in model_paths:
112
  model = Seq2LabelsModel.from_pretrained(model_path)
113
  config = model.config
 
337
  for output, weight in zip(data, self.model_weights):
338
  class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
339
  all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
340
+ class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
341
+ error_probs_d = class_probabilities_d[:, :, self.incorr_index]
342
+ incorr_prob = torch.max(error_probs_d, dim=-1)[0]
343
+ error_probs += weight * incorr_prob / sum(self.model_weights)
344
 
345
  max_vals = torch.max(all_class_probs, dim=-1)
346
  probs = max_vals[0].tolist()