Tymec committed on
Commit
5a2db0a
·
1 Parent(s): 88f3204

Add cross validation

Browse files
Files changed (3) hide show
  1. app/cli.py +24 -9
  2. app/model.py +34 -5
  3. notebook.ipynb +11 -3
app/cli.py CHANGED
@@ -90,6 +90,13 @@ def predict(model_path: Path, text: list[str]) -> None:
90
  show_default=True,
91
  type=click.IntRange(1, None),
92
  )
 
 
 
 
 
 
 
93
  @click.option(
94
  "--seed",
95
  default=42,
@@ -97,19 +104,26 @@ def predict(model_path: Path, text: list[str]) -> None:
97
  show_default=True,
98
  type=click.IntRange(-1, None),
99
  )
 
 
 
 
 
100
  def train(
101
  dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
102
  max_features: int,
 
103
  seed: int,
 
104
  ) -> None:
105
  """Train the model on the provided dataset"""
106
  import joblib
107
 
108
  from app.constants import MODELS_DIR
109
- from app.model import create_model, load_data, train_model
110
 
111
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
112
- if model_path.exists():
113
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
114
 
115
  click.echo("Preprocessing dataset... ", nl=False)
@@ -122,16 +136,17 @@ def train(
122
 
123
  # click.echo("Training model... ", nl=False)
124
  click.echo("Training model... ")
125
- accuracy = train_model(model, text_data, label_data)
126
- joblib.dump(model, model_path)
127
- click.echo("Model saved to: ", nl=False)
128
- click.secho(str(model_path), fg="blue")
129
-
130
  click.echo("Model accuracy: ", nl=False)
131
  click.secho(f"{accuracy:.2%}", fg="blue")
132
 
133
- # TODO: Add hyperparameter options
134
- # TODO: Random/grid search for finding best classifier and hyperparameters
 
 
 
 
 
135
 
136
 
137
  def cli_wrapper() -> None:
 
90
  show_default=True,
91
  type=click.IntRange(1, None),
92
  )
93
+ @click.option(
94
+ "--cv",
95
+ default=5,
96
+ help="Number of cross-validation folds",
97
+ show_default=True,
98
+ type=click.IntRange(1, 50),
99
+ )
100
  @click.option(
101
  "--seed",
102
  default=42,
 
104
  show_default=True,
105
  type=click.IntRange(-1, None),
106
  )
107
+ @click.option(
108
+ "--force",
109
+ is_flag=True,
110
+ help="Overwrite the model file if it already exists",
111
+ )
112
  def train(
113
  dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
114
  max_features: int,
115
+ cv: int,
116
  seed: int,
117
+ force: bool,
118
  ) -> None:
119
  """Train the model on the provided dataset"""
120
  import joblib
121
 
122
  from app.constants import MODELS_DIR
123
+ from app.model import create_model, evaluate_model, load_data, train_model
124
 
125
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
126
+ if model_path.exists() and not force:
127
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
128
 
129
  click.echo("Preprocessing dataset... ", nl=False)
 
136
 
137
  # click.echo("Training model... ", nl=False)
138
  click.echo("Training model... ")
139
+ accuracy, text_test, text_label = train_model(model, text_data, label_data)
 
 
 
 
140
  click.echo("Model accuracy: ", nl=False)
141
  click.secho(f"{accuracy:.2%}", fg="blue")
142
 
143
+ click.echo("Model saved to: ", nl=False)
144
+ joblib.dump(model, model_path)
145
+ click.secho(str(model_path), fg="blue")
146
+
147
+ click.echo("Evaluating model... ", nl=False)
148
+ acc_mean, acc_std = evaluate_model(model, text_test, text_label, cv=cv)
149
+ click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
150
 
151
 
152
  def cli_wrapper() -> None:
app/model.py CHANGED
@@ -13,7 +13,7 @@ from nltk.stem import WordNetLemmatizer
13
  from sklearn.base import BaseEstimator, TransformerMixin
14
  from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
15
  from sklearn.linear_model import LogisticRegression
16
- from sklearn.model_selection import train_test_split
17
  from sklearn.pipeline import Pipeline
18
 
19
  from app.constants import (
@@ -28,7 +28,7 @@ from app.constants import (
28
  URL_REGEX,
29
  )
30
 
31
- __all__ = ["load_data", "create_model", "train_model"]
32
 
33
 
34
  class TextCleaner(BaseEstimator, TransformerMixin):
@@ -293,7 +293,7 @@ def train_model(
293
  text_data: list[str],
294
  label_data: list[int],
295
  seed: int = 42,
296
- ) -> float:
297
  """Train the sentiment analysis model.
298
 
299
  Args:
@@ -303,7 +303,7 @@ def train_model(
303
  seed: Random seed (None for random seed)
304
 
305
  Returns:
306
- Accuracy score
307
  """
308
  text_train, text_test, label_train, label_test = train_test_split(
309
  text_data,
@@ -316,4 +316,33 @@ def train_model(
316
  warnings.simplefilter("ignore")
317
  model.fit(text_train, label_train)
318
 
319
- return model.score(text_test, label_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from sklearn.base import BaseEstimator, TransformerMixin
14
  from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
15
  from sklearn.linear_model import LogisticRegression
16
+ from sklearn.model_selection import cross_val_score, train_test_split
17
  from sklearn.pipeline import Pipeline
18
 
19
  from app.constants import (
 
28
  URL_REGEX,
29
  )
30
 
31
+ __all__ = ["load_data", "create_model", "train_model", "evaluate_model"]
32
 
33
 
34
  class TextCleaner(BaseEstimator, TransformerMixin):
 
293
  text_data: list[str],
294
  label_data: list[int],
295
  seed: int = 42,
296
+ ) -> tuple[float, list[str], list[int]]:
297
  """Train the sentiment analysis model.
298
 
299
  Args:
 
303
  seed: Random seed (None for random seed)
304
 
305
  Returns:
306
+ Model accuracy and test data
307
  """
308
  text_train, text_test, label_train, label_test = train_test_split(
309
  text_data,
 
316
  warnings.simplefilter("ignore")
317
  model.fit(text_train, label_train)
318
 
319
+ return model.score(text_test, label_test), text_test, label_test
320
+
321
+
322
+ def evaluate_model(
323
+ model: Pipeline,
324
+ text_test: list[str],
325
+ label_test: list[int],
326
+ cv: int = 5,
327
+ ) -> tuple[float, float]:
328
+ """Evaluate the model using cross-validation.
329
+
330
+ Args:
331
+ model: Trained model
332
+ text_test: Text data
333
+ label_test: Label data
334
+ seed: Random seed (None for random seed)
335
+ cv: Number of cross-validation folds
336
+
337
+ Returns:
338
+ Mean accuracy and standard deviation
339
+ """
340
+ scores = cross_val_score(
341
+ model,
342
+ text_test,
343
+ label_test,
344
+ cv=cv,
345
+ scoring="accuracy",
346
+ n_jobs=-1,
347
+ )
348
+ return scores.mean(), scores.std()
notebook.ipynb CHANGED
@@ -668,9 +668,17 @@
668
  },
669
  {
670
  "cell_type": "code",
671
- "execution_count": null,
672
  "metadata": {},
673
- "outputs": [],
 
 
 
 
 
 
 
 
674
  "source": [
675
  "# SVM\n",
676
  "svm_clf = SVC(random_state=SEED)\n",
@@ -680,7 +688,7 @@
680
  " svm_clf,\n",
681
  " {\n",
682
  " \"C\": np.logspace(-4, 4, 20),\n",
683
- " \"kernel\": [\"linear\", \"poly\", \"rbf\", \"sigmoid\"],\n",
684
  " \"degree\": [2, 3, 4],\n",
685
  " },\n",
686
  ")\n",
 
668
  },
669
  {
670
  "cell_type": "code",
671
+ "execution_count": 24,
672
  "metadata": {},
673
+ "outputs": [
674
+ {
675
+ "name": "stdout",
676
+ "output_type": "stream",
677
+ "text": [
678
+ "Fitting 3 folds for each of 10 candidates, totalling 30 fits\n"
679
+ ]
680
+ }
681
+ ],
682
  "source": [
683
  "# SVM\n",
684
  "svm_clf = SVC(random_state=SEED)\n",
 
688
  " svm_clf,\n",
689
  " {\n",
690
  " \"C\": np.logspace(-4, 4, 20),\n",
691
+ " \"kernel\": [\"linear\", \"poly\", \"rbf\"],\n",
692
  " \"degree\": [2, 3, 4],\n",
693
  " },\n",
694
  ")\n",