mgyigit commited on
Commit
de474d4
·
verified ·
1 Parent(s): 0d7fce8

Update src/bin/binding_affinity_estimator.py

Browse files
src/bin/binding_affinity_estimator.py CHANGED
@@ -47,42 +47,28 @@ def calc_validation_error(X_test, y_test, model):
47
  def calc_metrics(X_train, y_train, X_test, y_test, model):
48
  '''Fits the model and returns the metrics for in-sample and out-of-sample errors.'''
49
  model.fit(X_train, y_train)
50
- train_mse_error, train_mae_error, train_corr = calc_train_error(X_train, y_train, model)
51
  val_mse_error, val_mae_error, val_corr = calc_validation_error(X_test, y_test, model)
52
- return train_mse_error, val_mse_error, train_mae_error, val_mae_error, train_corr, val_corr
53
 
54
  def report_results(
55
- train_mse_error_list,
56
  validation_mse_error_list,
57
- train_mae_error_list,
58
  validation_mae_error_list,
59
- train_corr_list,
60
  validation_corr_list,
61
- train_corr_pval_list,
62
  validation_corr_pval_list,
63
  ):
64
  result_summary = {
65
- "train_mse_error": round(np.mean(train_mse_error_list) * 100, 4),
66
- "train_mse_std": round(np.std(train_mse_error_list) * 100, 4),
67
  "val_mse_error": round(np.mean(validation_mse_error_list) * 100, 4),
68
  "val_mse_std": round(np.std(validation_mse_error_list) * 100, 4),
69
- "train_mae_error": round(np.mean(train_mae_error_list) * 100, 4),
70
- "train_mae_std": round(np.std(train_mae_error_list) * 100, 4),
71
  "val_mae_error": round(np.mean(validation_mae_error_list) * 100, 4),
72
  "val_mae_std": round(np.std(validation_mae_error_list) * 100, 4),
73
- "train_corr": round(np.mean(train_corr_list), 4),
74
- "train_corr_pval": round(np.mean(train_corr_pval_list), 4),
75
  "validation_corr": round(np.mean(validation_corr_list), 4),
76
  "validation_corr_pval": round(np.mean(validation_corr_pval_list), 4),
77
  }
78
 
79
  result_detail = {
80
- "train_mse_errors": list(np.multiply(train_mse_error_list, 100)),
81
  "val_mse_errors": list(np.multiply(validation_mse_error_list, 100)),
82
- "train_mae_errors": list(np.multiply(train_mae_error_list, 100)),
83
  "val_mae_errors": list(np.multiply(validation_mae_error_list, 100)),
84
- "train_corrs": list(np.multiply(train_corr_list, 100)),
85
- "train_corr_pvals": list(np.multiply(train_corr_pval_list, 100)),
86
  "validation_corrs": list(np.multiply(validation_corr_list, 100)),
87
  "validation_corr_pvals": list(np.multiply(validation_corr_pval_list, 100)),
88
  }
@@ -123,35 +109,24 @@ def predictAffinityWithModel(regressor_model, multiplied_vectors_df):
123
 
124
  # calculate errors
125
  (
126
- train_mse_error,
127
  val_mse_error,
128
- train_mae_error,
129
  val_mae_error,
130
- train_corr,
131
  val_corr,
132
  ) = calc_metrics(X_train, y_train, X_val, y_val, reg)
133
 
134
  # append to appropriate lists
135
- train_mse_error_list.append(train_mse_error)
136
  validation_mse_error_list.append(val_mse_error)
137
 
138
- train_mae_error_list.append(train_mae_error)
139
  validation_mae_error_list.append(val_mae_error)
140
 
141
- train_corr_list.append(train_corr[0])
142
  validation_corr_list.append(val_corr[0])
143
 
144
- train_corr_pval_list.append(train_corr[1])
145
  validation_corr_pval_list.append(val_corr[1])
146
 
147
  return report_results(
148
- train_mse_error_list,
149
  validation_mse_error_list,
150
- train_mae_error_list,
151
  validation_mae_error_list,
152
- train_corr_list,
153
  validation_corr_list,
154
- train_corr_pval_list,
155
  validation_corr_pval_list,
156
  )
157
 
 
47
  def calc_metrics(X_train, y_train, X_test, y_test, model):
48
  '''Fits the model and returns the metrics for in-sample and out-of-sample errors.'''
49
  model.fit(X_train, y_train)
50
+ #train_mse_error, train_mae_error, train_corr = calc_train_error(X_train, y_train, model)
51
  val_mse_error, val_mae_error, val_corr = calc_validation_error(X_test, y_test, model)
52
+ return val_mse_error, val_mae_error, val_corr
53
 
54
  def report_results(
 
55
  validation_mse_error_list,
 
56
  validation_mae_error_list,
 
57
  validation_corr_list,
 
58
  validation_corr_pval_list,
59
  ):
60
  result_summary = {
 
 
61
  "val_mse_error": round(np.mean(validation_mse_error_list) * 100, 4),
62
  "val_mse_std": round(np.std(validation_mse_error_list) * 100, 4),
 
 
63
  "val_mae_error": round(np.mean(validation_mae_error_list) * 100, 4),
64
  "val_mae_std": round(np.std(validation_mae_error_list) * 100, 4),
 
 
65
  "validation_corr": round(np.mean(validation_corr_list), 4),
66
  "validation_corr_pval": round(np.mean(validation_corr_pval_list), 4),
67
  }
68
 
69
  result_detail = {
 
70
  "val_mse_errors": list(np.multiply(validation_mse_error_list, 100)),
 
71
  "val_mae_errors": list(np.multiply(validation_mae_error_list, 100)),
 
 
72
  "validation_corrs": list(np.multiply(validation_corr_list, 100)),
73
  "validation_corr_pvals": list(np.multiply(validation_corr_pval_list, 100)),
74
  }
 
109
 
110
  # calculate errors
111
  (
 
112
  val_mse_error,
 
113
  val_mae_error,
 
114
  val_corr,
115
  ) = calc_metrics(X_train, y_train, X_val, y_val, reg)
116
 
117
  # append to appropriate lists
 
118
  validation_mse_error_list.append(val_mse_error)
119
 
 
120
  validation_mae_error_list.append(val_mae_error)
121
 
 
122
  validation_corr_list.append(val_corr[0])
123
 
 
124
  validation_corr_pval_list.append(val_corr[1])
125
 
126
  return report_results(
 
127
  validation_mse_error_list,
 
128
  validation_mae_error_list,
 
129
  validation_corr_list,
 
130
  validation_corr_pval_list,
131
  )
132