Jensen-holm committed on
Commit
8176fea
·
1 Parent(s): ea2bb97

copying the src from the github repository

Browse files
src/m_pp.ipynb ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd\n",
10
+ "import numpy as np\n",
11
+ "import os\n",
12
+ "\n",
13
+ "DATA_DIR = os.path.join(\"..\", \"data\")"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "name": "stdout",
23
+ "output_type": "stream",
24
+ "text": [
25
+ "<class 'pandas.core.frame.DataFrame'>\n",
26
+ "RangeIndex: 1315 entries, 0 to 1314\n",
27
+ "Data columns (total 38 columns):\n",
28
+ " # Column Non-Null Count Dtype \n",
29
+ "--- ------ -------------- ----- \n",
30
+ " 0 Season 1315 non-null int64 \n",
31
+ " 1 DayNum 1315 non-null int64 \n",
32
+ " 2 WTeamID 1315 non-null int64 \n",
33
+ " 3 WScore 1315 non-null int64 \n",
34
+ " 4 LTeamID 1315 non-null int64 \n",
35
+ " 5 LScore 1315 non-null int64 \n",
36
+ " 6 WLoc 1315 non-null int64 \n",
37
+ " 7 NumOT 1315 non-null int64 \n",
38
+ " 8 WFGM 1315 non-null int64 \n",
39
+ " 9 WFGA 1315 non-null int64 \n",
40
+ " 10 WFGM3 1315 non-null int64 \n",
41
+ " 11 WFGA3 1315 non-null int64 \n",
42
+ " 12 WFTM 1315 non-null int64 \n",
43
+ " 13 WFTA 1315 non-null int64 \n",
44
+ " 14 WOR 1315 non-null int64 \n",
45
+ " 15 WDR 1315 non-null int64 \n",
46
+ " 16 WAst 1315 non-null int64 \n",
47
+ " 17 WTO 1315 non-null int64 \n",
48
+ " 18 WStl 1315 non-null int64 \n",
49
+ " 19 WBlk 1315 non-null int64 \n",
50
+ " 20 WPF 1315 non-null int64 \n",
51
+ " 21 LFGM 1315 non-null int64 \n",
52
+ " 22 LFGA 1315 non-null int64 \n",
53
+ " 23 LFGM3 1315 non-null int64 \n",
54
+ " 24 LFGA3 1315 non-null int64 \n",
55
+ " 25 LFTM 1315 non-null int64 \n",
56
+ " 26 LFTA 1315 non-null int64 \n",
57
+ " 27 LOR 1315 non-null int64 \n",
58
+ " 28 LDR 1315 non-null int64 \n",
59
+ " 29 LAst 1315 non-null int64 \n",
60
+ " 30 LTO 1315 non-null int64 \n",
61
+ " 31 LStl 1315 non-null int64 \n",
62
+ " 32 LBlk 1315 non-null int64 \n",
63
+ " 33 LPF 1315 non-null int64 \n",
64
+ " 34 GameType 1315 non-null object\n",
65
+ " 35 WPA 1315 non-null int64 \n",
66
+ " 36 LPA 1315 non-null int64 \n",
67
+ " 37 LLoc 1315 non-null int64 \n",
68
+ "dtypes: int64(37), object(1)\n",
69
+ "memory usage: 390.5+ KB\n"
70
+ ]
71
+ }
72
+ ],
73
+ "source": [
74
+ "tourney_games_df = pd.read_csv(\n",
75
+ " os.path.join(DATA_DIR, \"MNCAATourneyDetailedResults.csv\")\n",
76
+ ")\n",
77
+ "\n",
78
+ "tourney_games_df[\"GameType\"] = \"tourney\"\n",
79
+ "\n",
80
+ "tourney_games_df[\"WPA\"] = tourney_games_df[\"LScore\"]\n",
81
+ "tourney_games_df[\"LPA\"] = tourney_games_df[\"WScore\"]\n",
82
+ "\n",
83
+ "tourney_games_df[\"LLoc\"] = tourney_games_df[\"WLoc\"].apply(lambda x: 0 if x == \"A\" else 1)\n",
84
+ "tourney_games_df[\"WLoc\"] = tourney_games_df[\"LLoc\"].apply(lambda x: 0 if x == \"A\" else 1)\n",
85
+ "\n",
86
+ "tourney_games_df.info()"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 3,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "name": "stdout",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "<class 'pandas.core.frame.DataFrame'>\n",
99
+ "RangeIndex: 111817 entries, 0 to 111816\n",
100
+ "Data columns (total 38 columns):\n",
101
+ " # Column Non-Null Count Dtype \n",
102
+ "--- ------ -------------- ----- \n",
103
+ " 0 Season 111817 non-null int64 \n",
104
+ " 1 DayNum 111817 non-null int64 \n",
105
+ " 2 WTeamID 111817 non-null int64 \n",
106
+ " 3 WScore 111817 non-null int64 \n",
107
+ " 4 LTeamID 111817 non-null int64 \n",
108
+ " 5 LScore 111817 non-null int64 \n",
109
+ " 6 WLoc 111817 non-null int64 \n",
110
+ " 7 NumOT 111817 non-null int64 \n",
111
+ " 8 WFGM 111817 non-null int64 \n",
112
+ " 9 WFGA 111817 non-null int64 \n",
113
+ " 10 WFGM3 111817 non-null int64 \n",
114
+ " 11 WFGA3 111817 non-null int64 \n",
115
+ " 12 WFTM 111817 non-null int64 \n",
116
+ " 13 WFTA 111817 non-null int64 \n",
117
+ " 14 WOR 111817 non-null int64 \n",
118
+ " 15 WDR 111817 non-null int64 \n",
119
+ " 16 WAst 111817 non-null int64 \n",
120
+ " 17 WTO 111817 non-null int64 \n",
121
+ " 18 WStl 111817 non-null int64 \n",
122
+ " 19 WBlk 111817 non-null int64 \n",
123
+ " 20 WPF 111817 non-null int64 \n",
124
+ " 21 LFGM 111817 non-null int64 \n",
125
+ " 22 LFGA 111817 non-null int64 \n",
126
+ " 23 LFGM3 111817 non-null int64 \n",
127
+ " 24 LFGA3 111817 non-null int64 \n",
128
+ " 25 LFTM 111817 non-null int64 \n",
129
+ " 26 LFTA 111817 non-null int64 \n",
130
+ " 27 LOR 111817 non-null int64 \n",
131
+ " 28 LDR 111817 non-null int64 \n",
132
+ " 29 LAst 111817 non-null int64 \n",
133
+ " 30 LTO 111817 non-null int64 \n",
134
+ " 31 LStl 111817 non-null int64 \n",
135
+ " 32 LBlk 111817 non-null int64 \n",
136
+ " 33 LPF 111817 non-null int64 \n",
137
+ " 34 GameType 111817 non-null object\n",
138
+ " 35 WPA 111817 non-null int64 \n",
139
+ " 36 LPA 111817 non-null int64 \n",
140
+ " 37 LLoc 111817 non-null int64 \n",
141
+ "dtypes: int64(37), object(1)\n",
142
+ "memory usage: 32.4+ MB\n"
143
+ ]
144
+ }
145
+ ],
146
+ "source": [
147
+ "reg_games_df = pd.read_csv(\n",
148
+ " os.path.join(DATA_DIR, \"MRegularSeasonDetailedResults.csv\")\n",
149
+ ")\n",
150
+ "\n",
151
+ "reg_games_df[\"GameType\"] = \"reg\"\n",
152
+ "\n",
153
+ "# points allowed column\n",
154
+ "reg_games_df[\"WPA\"] = reg_games_df[\"LScore\"]\n",
155
+ "reg_games_df[\"LPA\"] = reg_games_df[\"WScore\"]\n",
156
+ "\n",
157
+ "# loser location column\n",
158
+ "reg_games_df[\"LLoc\"] = reg_games_df[\"WLoc\"].apply(lambda x: 0 if x == \"A\" else 1)\n",
159
+ "reg_games_df[\"WLoc\"] = reg_games_df[\"LLoc\"].apply(lambda x: 0 if x == \"A\" else 1)\n",
160
+ "\n",
161
+ "reg_games_df.info()"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 10,
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "\n",
171
+ "def flatten_multi_idx(df: pd.DataFrame) -> None:\n",
172
+ " df.columns = [\"_\".join(filter(None, col)) for col in df.columns.to_flat_index()]\n",
173
+ "\n",
174
+ "\n",
175
+ "def summarize_teams(df: pd.DataFrame) -> pd.DataFrame:\n",
176
+ " other_cols = {\"TeamID\", \"WTeamID\", \"LTeamID\", \"DayNum\", \"Season\", \"GameType\", \"total_games\"}\n",
177
+ " agg_funcs = [np.sum, np.mean, np.median, np.std, np.min, np.max]\n",
178
+ " dfs = {}\n",
179
+ " subsets = [\"W\", \"L\"]\n",
180
+ " for subset in subsets:\n",
181
+ " sub = df[[col for col in df.columns if subset in col or col in other_cols]]\n",
182
+ " agg_df = sub \\\n",
183
+ " .groupby([f\"{subset}TeamID\", \"Season\"]) \\\n",
184
+ " .agg({col: agg_funcs for col in sub.columns if col not in other_cols}) \\\n",
185
+ " .reset_index()\n",
186
+ " \n",
187
+ " flatten_multi_idx(agg_df)\n",
188
+ " agg_df[f\"total{subset}\"] = df \\\n",
189
+ " .groupby([f\"{subset}TeamID\", \"Season\"])[f\"{subset}TeamID\"] \\\n",
190
+ " .transform(\"count\")\n",
191
+ " dfs[subset] = agg_df\n",
192
+ "\n",
193
+ " merged = pd.merge(\n",
194
+ " left=dfs[\"W\"],\n",
195
+ " right=dfs[\"L\"],\n",
196
+ " left_on=[\"WTeamID\", \"Season\"],\n",
197
+ " right_on=[\"LTeamID\", \"Season\"],\n",
198
+ " )\n",
199
+ "\n",
200
+ " merged[\"total_games\"] = merged[\"totalW\"] + merged[\"totalL\"]\n",
201
+ " merged[\"TeamID\"] = merged[\"WTeamID\"]\n",
202
+ " merged.drop([\"WTeamID\", \"LTeamID\"], axis=1, inplace=True)\n",
203
+ " return merged\n",
204
+ "\n",
205
+ " # overall_stats_df = merged[[\"TeamID\", \"Season\", \"total_games\", \"WPA_sum\", \"LPA_sum\", \"total_games\"]]\n",
206
+ " # # Combine stats from games won and games lost\n",
207
+ " # overall_stats_df[\"TotalPA\"] = overall_stats_df[\"WPA_sum\"] + overall_stats_df[\"LPA_sum\"]\n",
208
+ " return merged\n"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 11,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "reg_agg_df = summarize_teams(reg_games_df)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 12,
223
+ "metadata": {},
224
+ "outputs": [
225
+ {
226
+ "data": {
227
+ "text/html": [
228
+ "<div>\n",
229
+ "<style scoped>\n",
230
+ " .dataframe tbody tr th:only-of-type {\n",
231
+ " vertical-align: middle;\n",
232
+ " }\n",
233
+ "\n",
234
+ " .dataframe tbody tr th {\n",
235
+ " vertical-align: top;\n",
236
+ " }\n",
237
+ "\n",
238
+ " .dataframe thead th {\n",
239
+ " text-align: right;\n",
240
+ " }\n",
241
+ "</style>\n",
242
+ "<table border=\"1\" class=\"dataframe\">\n",
243
+ " <thead>\n",
244
+ " <tr style=\"text-align: right;\">\n",
245
+ " <th></th>\n",
246
+ " <th>Season</th>\n",
247
+ " <th>WScore_sum</th>\n",
248
+ " <th>WScore_mean</th>\n",
249
+ " <th>WScore_median</th>\n",
250
+ " <th>WScore_std</th>\n",
251
+ " <th>WScore_min</th>\n",
252
+ " <th>WScore_max</th>\n",
253
+ " <th>WLoc_sum_x</th>\n",
254
+ " <th>WLoc_mean_x</th>\n",
255
+ " <th>WLoc_median_x</th>\n",
256
+ " <th>...</th>\n",
257
+ " <th>LPA_max</th>\n",
258
+ " <th>LLoc_sum</th>\n",
259
+ " <th>LLoc_mean</th>\n",
260
+ " <th>LLoc_median</th>\n",
261
+ " <th>LLoc_std</th>\n",
262
+ " <th>LLoc_min</th>\n",
263
+ " <th>LLoc_max</th>\n",
264
+ " <th>totalL</th>\n",
265
+ " <th>total_games</th>\n",
266
+ " <th>TeamID</th>\n",
267
+ " </tr>\n",
268
+ " </thead>\n",
269
+ " <tbody>\n",
270
+ " <tr>\n",
271
+ " <th>0</th>\n",
272
+ " <td>2014</td>\n",
273
+ " <td>160</td>\n",
274
+ " <td>80.000000</td>\n",
275
+ " <td>80.0</td>\n",
276
+ " <td>9.899495</td>\n",
277
+ " <td>73</td>\n",
278
+ " <td>87</td>\n",
279
+ " <td>2</td>\n",
280
+ " <td>1.0</td>\n",
281
+ " <td>1.0</td>\n",
282
+ " <td>...</td>\n",
283
+ " <td>103</td>\n",
284
+ " <td>14</td>\n",
285
+ " <td>0.736842</td>\n",
286
+ " <td>1.0</td>\n",
287
+ " <td>0.452414</td>\n",
288
+ " <td>0</td>\n",
289
+ " <td>1</td>\n",
290
+ " <td>6</td>\n",
291
+ " <td>23</td>\n",
292
+ " <td>1101</td>\n",
293
+ " </tr>\n",
294
+ " <tr>\n",
295
+ " <th>1</th>\n",
296
+ " <td>2015</td>\n",
297
+ " <td>542</td>\n",
298
+ " <td>77.428571</td>\n",
299
+ " <td>72.0</td>\n",
300
+ " <td>11.012979</td>\n",
301
+ " <td>65</td>\n",
302
+ " <td>95</td>\n",
303
+ " <td>7</td>\n",
304
+ " <td>1.0</td>\n",
305
+ " <td>1.0</td>\n",
306
+ " <td>...</td>\n",
307
+ " <td>102</td>\n",
308
+ " <td>15</td>\n",
309
+ " <td>0.714286</td>\n",
310
+ " <td>1.0</td>\n",
311
+ " <td>0.462910</td>\n",
312
+ " <td>0</td>\n",
313
+ " <td>1</td>\n",
314
+ " <td>5</td>\n",
315
+ " <td>28</td>\n",
316
+ " <td>1101</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <th>2</th>\n",
320
+ " <td>2016</td>\n",
321
+ " <td>704</td>\n",
322
+ " <td>78.222222</td>\n",
323
+ " <td>79.0</td>\n",
324
+ " <td>9.257129</td>\n",
325
+ " <td>62</td>\n",
326
+ " <td>91</td>\n",
327
+ " <td>9</td>\n",
328
+ " <td>1.0</td>\n",
329
+ " <td>1.0</td>\n",
330
+ " <td>...</td>\n",
331
+ " <td>108</td>\n",
332
+ " <td>13</td>\n",
333
+ " <td>0.722222</td>\n",
334
+ " <td>1.0</td>\n",
335
+ " <td>0.460889</td>\n",
336
+ " <td>0</td>\n",
337
+ " <td>1</td>\n",
338
+ " <td>15</td>\n",
339
+ " <td>38</td>\n",
340
+ " <td>1101</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <th>3</th>\n",
344
+ " <td>2017</td>\n",
345
+ " <td>669</td>\n",
346
+ " <td>74.333333</td>\n",
347
+ " <td>71.0</td>\n",
348
+ " <td>7.648529</td>\n",
349
+ " <td>65</td>\n",
350
+ " <td>85</td>\n",
351
+ " <td>9</td>\n",
352
+ " <td>1.0</td>\n",
353
+ " <td>1.0</td>\n",
354
+ " <td>...</td>\n",
355
+ " <td>89</td>\n",
356
+ " <td>11</td>\n",
357
+ " <td>0.687500</td>\n",
358
+ " <td>1.0</td>\n",
359
+ " <td>0.478714</td>\n",
360
+ " <td>0</td>\n",
361
+ " <td>1</td>\n",
362
+ " <td>10</td>\n",
363
+ " <td>27</td>\n",
364
+ " <td>1101</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <th>4</th>\n",
368
+ " <td>2018</td>\n",
369
+ " <td>915</td>\n",
370
+ " <td>76.250000</td>\n",
371
+ " <td>77.0</td>\n",
372
+ " <td>7.484833</td>\n",
373
+ " <td>62</td>\n",
374
+ " <td>88</td>\n",
375
+ " <td>12</td>\n",
376
+ " <td>1.0</td>\n",
377
+ " <td>1.0</td>\n",
378
+ " <td>...</td>\n",
379
+ " <td>88</td>\n",
380
+ " <td>9</td>\n",
381
+ " <td>0.600000</td>\n",
382
+ " <td>1.0</td>\n",
383
+ " <td>0.507093</td>\n",
384
+ " <td>0</td>\n",
385
+ " <td>1</td>\n",
386
+ " <td>8</td>\n",
387
+ " <td>30</td>\n",
388
+ " <td>1101</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <th>...</th>\n",
392
+ " <td>...</td>\n",
393
+ " <td>...</td>\n",
394
+ " <td>...</td>\n",
395
+ " <td>...</td>\n",
396
+ " <td>...</td>\n",
397
+ " <td>...</td>\n",
398
+ " <td>...</td>\n",
399
+ " <td>...</td>\n",
400
+ " <td>...</td>\n",
401
+ " <td>...</td>\n",
402
+ " <td>...</td>\n",
403
+ " <td>...</td>\n",
404
+ " <td>...</td>\n",
405
+ " <td>...</td>\n",
406
+ " <td>...</td>\n",
407
+ " <td>...</td>\n",
408
+ " <td>...</td>\n",
409
+ " <td>...</td>\n",
410
+ " <td>...</td>\n",
411
+ " <td>...</td>\n",
412
+ " <td>...</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <th>7600</th>\n",
416
+ " <td>2023</td>\n",
417
+ " <td>920</td>\n",
418
+ " <td>70.769231</td>\n",
419
+ " <td>73.0</td>\n",
420
+ " <td>9.047595</td>\n",
421
+ " <td>51</td>\n",
422
+ " <td>82</td>\n",
423
+ " <td>13</td>\n",
424
+ " <td>1.0</td>\n",
425
+ " <td>1.0</td>\n",
426
+ " <td>...</td>\n",
427
+ " <td>102</td>\n",
428
+ " <td>13</td>\n",
429
+ " <td>0.764706</td>\n",
430
+ " <td>1.0</td>\n",
431
+ " <td>0.437237</td>\n",
432
+ " <td>0</td>\n",
433
+ " <td>1</td>\n",
434
+ " <td>14</td>\n",
435
+ " <td>29</td>\n",
436
+ " <td>1476</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <th>7601</th>\n",
440
+ " <td>2024</td>\n",
441
+ " <td>128</td>\n",
442
+ " <td>64.000000</td>\n",
443
+ " <td>64.0</td>\n",
444
+ " <td>9.899495</td>\n",
445
+ " <td>57</td>\n",
446
+ " <td>71</td>\n",
447
+ " <td>2</td>\n",
448
+ " <td>1.0</td>\n",
449
+ " <td>1.0</td>\n",
450
+ " <td>...</td>\n",
451
+ " <td>107</td>\n",
452
+ " <td>17</td>\n",
453
+ " <td>0.739130</td>\n",
454
+ " <td>1.0</td>\n",
455
+ " <td>0.448978</td>\n",
456
+ " <td>0</td>\n",
457
+ " <td>1</td>\n",
458
+ " <td>5</td>\n",
459
+ " <td>25</td>\n",
460
+ " <td>1476</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <th>7602</th>\n",
464
+ " <td>2023</td>\n",
465
+ " <td>864</td>\n",
466
+ " <td>72.000000</td>\n",
467
+ " <td>74.0</td>\n",
468
+ " <td>10.206950</td>\n",
469
+ " <td>53</td>\n",
470
+ " <td>84</td>\n",
471
+ " <td>12</td>\n",
472
+ " <td>1.0</td>\n",
473
+ " <td>1.0</td>\n",
474
+ " <td>...</td>\n",
475
+ " <td>97</td>\n",
476
+ " <td>15</td>\n",
477
+ " <td>0.750000</td>\n",
478
+ " <td>1.0</td>\n",
479
+ " <td>0.444262</td>\n",
480
+ " <td>0</td>\n",
481
+ " <td>1</td>\n",
482
+ " <td>20</td>\n",
483
+ " <td>34</td>\n",
484
+ " <td>1477</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <th>7603</th>\n",
488
+ " <td>2024</td>\n",
489
+ " <td>483</td>\n",
490
+ " <td>80.500000</td>\n",
491
+ " <td>80.0</td>\n",
492
+ " <td>17.683325</td>\n",
493
+ " <td>57</td>\n",
494
+ " <td>101</td>\n",
495
+ " <td>6</td>\n",
496
+ " <td>1.0</td>\n",
497
+ " <td>1.0</td>\n",
498
+ " <td>...</td>\n",
499
+ " <td>90</td>\n",
500
+ " <td>10</td>\n",
501
+ " <td>0.625000</td>\n",
502
+ " <td>1.0</td>\n",
503
+ " <td>0.500000</td>\n",
504
+ " <td>0</td>\n",
505
+ " <td>1</td>\n",
506
+ " <td>9</td>\n",
507
+ " <td>33</td>\n",
508
+ " <td>1477</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <th>7604</th>\n",
512
+ " <td>2024</td>\n",
513
+ " <td>578</td>\n",
514
+ " <td>82.571429</td>\n",
515
+ " <td>80.0</td>\n",
516
+ " <td>7.345228</td>\n",
517
+ " <td>74</td>\n",
518
+ " <td>94</td>\n",
519
+ " <td>7</td>\n",
520
+ " <td>1.0</td>\n",
521
+ " <td>1.0</td>\n",
522
+ " <td>...</td>\n",
523
+ " <td>96</td>\n",
524
+ " <td>12</td>\n",
525
+ " <td>0.857143</td>\n",
526
+ " <td>1.0</td>\n",
527
+ " <td>0.363137</td>\n",
528
+ " <td>0</td>\n",
529
+ " <td>1</td>\n",
530
+ " <td>12</td>\n",
531
+ " <td>26</td>\n",
532
+ " <td>1478</td>\n",
533
+ " </tr>\n",
534
+ " </tbody>\n",
535
+ "</table>\n",
536
+ "<p>7605 rows × 203 columns</p>\n",
537
+ "</div>"
538
+ ],
539
+ "text/plain": [
540
+ " Season WScore_sum WScore_mean WScore_median WScore_std WScore_min \\\n",
541
+ "0 2014 160 80.000000 80.0 9.899495 73 \n",
542
+ "1 2015 542 77.428571 72.0 11.012979 65 \n",
543
+ "2 2016 704 78.222222 79.0 9.257129 62 \n",
544
+ "3 2017 669 74.333333 71.0 7.648529 65 \n",
545
+ "4 2018 915 76.250000 77.0 7.484833 62 \n",
546
+ "... ... ... ... ... ... ... \n",
547
+ "7600 2023 920 70.769231 73.0 9.047595 51 \n",
548
+ "7601 2024 128 64.000000 64.0 9.899495 57 \n",
549
+ "7602 2023 864 72.000000 74.0 10.206950 53 \n",
550
+ "7603 2024 483 80.500000 80.0 17.683325 57 \n",
551
+ "7604 2024 578 82.571429 80.0 7.345228 74 \n",
552
+ "\n",
553
+ " WScore_max WLoc_sum_x WLoc_mean_x WLoc_median_x ... LPA_max \\\n",
554
+ "0 87 2 1.0 1.0 ... 103 \n",
555
+ "1 95 7 1.0 1.0 ... 102 \n",
556
+ "2 91 9 1.0 1.0 ... 108 \n",
557
+ "3 85 9 1.0 1.0 ... 89 \n",
558
+ "4 88 12 1.0 1.0 ... 88 \n",
559
+ "... ... ... ... ... ... ... \n",
560
+ "7600 82 13 1.0 1.0 ... 102 \n",
561
+ "7601 71 2 1.0 1.0 ... 107 \n",
562
+ "7602 84 12 1.0 1.0 ... 97 \n",
563
+ "7603 101 6 1.0 1.0 ... 90 \n",
564
+ "7604 94 7 1.0 1.0 ... 96 \n",
565
+ "\n",
566
+ " LLoc_sum LLoc_mean LLoc_median LLoc_std LLoc_min LLoc_max totalL \\\n",
567
+ "0 14 0.736842 1.0 0.452414 0 1 6 \n",
568
+ "1 15 0.714286 1.0 0.462910 0 1 5 \n",
569
+ "2 13 0.722222 1.0 0.460889 0 1 15 \n",
570
+ "3 11 0.687500 1.0 0.478714 0 1 10 \n",
571
+ "4 9 0.600000 1.0 0.507093 0 1 8 \n",
572
+ "... ... ... ... ... ... ... ... \n",
573
+ "7600 13 0.764706 1.0 0.437237 0 1 14 \n",
574
+ "7601 17 0.739130 1.0 0.448978 0 1 5 \n",
575
+ "7602 15 0.750000 1.0 0.444262 0 1 20 \n",
576
+ "7603 10 0.625000 1.0 0.500000 0 1 9 \n",
577
+ "7604 12 0.857143 1.0 0.363137 0 1 12 \n",
578
+ "\n",
579
+ " total_games TeamID \n",
580
+ "0 23 1101 \n",
581
+ "1 28 1101 \n",
582
+ "2 38 1101 \n",
583
+ "3 27 1101 \n",
584
+ "4 30 1101 \n",
585
+ "... ... ... \n",
586
+ "7600 29 1476 \n",
587
+ "7601 25 1476 \n",
588
+ "7602 34 1477 \n",
589
+ "7603 33 1477 \n",
590
+ "7604 26 1478 \n",
591
+ "\n",
592
+ "[7605 rows x 203 columns]"
593
+ ]
594
+ },
595
+ "execution_count": 12,
596
+ "metadata": {},
597
+ "output_type": "execute_result"
598
+ }
599
+ ],
600
+ "source": [
601
+ "# combine the winning and losing stats so that we have overall game stats\n",
602
+ "reg_agg_df\n"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": null,
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": []
611
+ }
612
+ ],
613
+ "metadata": {
614
+ "kernelspec": {
615
+ "display_name": "Python 3 (ipykernel)",
616
+ "language": "python",
617
+ "name": "python3"
618
+ },
619
+ "language_info": {
620
+ "codemirror_mode": {
621
+ "name": "ipython",
622
+ "version": 3
623
+ },
624
+ "file_extension": ".py",
625
+ "mimetype": "text/x-python",
626
+ "name": "python",
627
+ "nbconvert_exporter": "python",
628
+ "pygments_lexer": "ipython3",
629
+ "version": "3.11.7"
630
+ }
631
+ },
632
+ "nbformat": 4,
633
+ "nbformat_minor": 2
634
+ }
src/mens_monte_carlo.ipynb ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd\n",
10
+ "import numpy as np\n",
11
+ "import os\n",
12
+ "\n",
13
+ "DATA_DIR = os.path.join(\"..\", \"data\")"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": []
22
+ }
23
+ ],
24
+ "metadata": {
25
+ "kernelspec": {
26
+ "display_name": "Python 3",
27
+ "language": "python",
28
+ "name": "python3"
29
+ },
30
+ "language_info": {
31
+ "codemirror_mode": {
32
+ "name": "ipython",
33
+ "version": 3
34
+ },
35
+ "file_extension": ".py",
36
+ "mimetype": "text/x-python",
37
+ "name": "python",
38
+ "nbconvert_exporter": "python",
39
+ "pygments_lexer": "ipython3",
40
+ "version": "3.11.7"
41
+ }
42
+ },
43
+ "nbformat": 4,
44
+ "nbformat_minor": 2
45
+ }
src/mens_nn.ipynb ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "998997dd",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Modeling NCAA Tournament Basketball games\n",
9
+ "\n",
10
+ "The thought process is to build a neural network that can predict a team's tournament <br>\n",
11
+ "performance on a per-game basis. Then we can use these predicted metrics to run a Monte Carlo <br>\n",
12
+ "style simulation and select whichever team is most likely to win. <br>"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 1,
18
+ "id": "f0ec30d9",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "from sklearn.model_selection import train_test_split\n",
23
+ "import torch\n",
24
+ "import torch.nn as nn\n",
25
+ "import torch.optim as optim\n",
26
+ "\n",
27
+ "import pandas as pd\n",
28
+ "import os\n",
29
+ "\n",
30
+ "\n",
31
+ "# check whether any GPUs are available for faster training\n",
32
+ "def get_device() -> str:\n",
33
+ " if torch.cuda.is_available():\n",
34
+ " return \"cuda\"\n",
35
+ " if torch.backends.mps.is_available():\n",
36
+ " return \"mps\" \n",
37
+ " return \"cpu\"\n",
38
+ "\n",
39
+ "# MPS is not working correctly on my M1 MacBook Air, so just using the CPU for now\n",
40
+ "# DEVICE = get_device()\n",
41
+ "DEVICE = \"cpu\"\n",
42
+ "\n",
43
+ "# universal data directory for this project\n",
44
+ "DATA_DIR = os.path.join(\"..\", \"data\") "
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 2,
50
+ "id": "b820f210",
51
+ "metadata": {},
52
+ "outputs": [
53
+ {
54
+ "name": "stdout",
55
+ "output_type": "stream",
56
+ "text": [
57
+ "<class 'pandas.core.frame.DataFrame'>\n",
58
+ "RangeIndex: 655 entries, 0 to 654\n",
59
+ "Columns: 1068 entries, Unnamed: 0 to Seed\n",
60
+ "dtypes: float64(672), int64(388), object(8)\n",
61
+ "memory usage: 5.3+ MB\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "all_games_df = pd.read_csv(os.path.join(DATA_DIR, \"MDetailedAggregatedGames.csv\"))\n",
67
+ "all_games_df.info()"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 3,
73
+ "id": "02ebc500",
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "data": {
78
+ "text/html": [
79
+ "<div>\n",
80
+ "<style scoped>\n",
81
+ " .dataframe tbody tr th:only-of-type {\n",
82
+ " vertical-align: middle;\n",
83
+ " }\n",
84
+ "\n",
85
+ " .dataframe tbody tr th {\n",
86
+ " vertical-align: top;\n",
87
+ " }\n",
88
+ "\n",
89
+ " .dataframe thead th {\n",
90
+ " text-align: right;\n",
91
+ " }\n",
92
+ "</style>\n",
93
+ "<table border=\"1\" class=\"dataframe\">\n",
94
+ " <thead>\n",
95
+ " <tr style=\"text-align: right;\">\n",
96
+ " <th></th>\n",
97
+ " <th>Unnamed: 0</th>\n",
98
+ " <th>Season</th>\n",
99
+ " <th>DayNum</th>\n",
100
+ " <th>WTeamID</th>\n",
101
+ " <th>WScore</th>\n",
102
+ " <th>LTeamID</th>\n",
103
+ " <th>LScore</th>\n",
104
+ " <th>WLoc</th>\n",
105
+ " <th>NumOT</th>\n",
106
+ " <th>WFGM</th>\n",
107
+ " <th>...</th>\n",
108
+ " <th>tourney_DR_max</th>\n",
109
+ " <th>tourney_DR_mean</th>\n",
110
+ " <th>tourney_DR_median</th>\n",
111
+ " <th>tourney_DR_std</th>\n",
112
+ " <th>tourney_DR_sum</th>\n",
113
+ " <th>ConfAbbrev</th>\n",
114
+ " <th>TeamName</th>\n",
115
+ " <th>FirstD1Season</th>\n",
116
+ " <th>LastD1Season</th>\n",
117
+ " <th>Seed</th>\n",
118
+ " </tr>\n",
119
+ " </thead>\n",
120
+ " <tbody>\n",
121
+ " <tr>\n",
122
+ " <th>0</th>\n",
123
+ " <td>0</td>\n",
124
+ " <td>2003</td>\n",
125
+ " <td>40</td>\n",
126
+ " <td>1266</td>\n",
127
+ " <td>63</td>\n",
128
+ " <td>1458</td>\n",
129
+ " <td>54</td>\n",
130
+ " <td>H</td>\n",
131
+ " <td>0</td>\n",
132
+ " <td>24</td>\n",
133
+ " <td>...</td>\n",
134
+ " <td>21.666667</td>\n",
135
+ " <td>21.666667</td>\n",
136
+ " <td>21.666667</td>\n",
137
+ " <td>21.666667</td>\n",
138
+ " <td>21.666667</td>\n",
139
+ " <td>big_ten</td>\n",
140
+ " <td>Wisconsin</td>\n",
141
+ " <td>1985</td>\n",
142
+ " <td>2024</td>\n",
143
+ " <td>Y05</td>\n",
144
+ " </tr>\n",
145
+ " <tr>\n",
146
+ " <th>1</th>\n",
147
+ " <td>5</td>\n",
148
+ " <td>2003</td>\n",
149
+ " <td>97</td>\n",
150
+ " <td>1266</td>\n",
151
+ " <td>68</td>\n",
152
+ " <td>1448</td>\n",
153
+ " <td>61</td>\n",
154
+ " <td>H</td>\n",
155
+ " <td>0</td>\n",
156
+ " <td>21</td>\n",
157
+ " <td>...</td>\n",
158
+ " <td>26.000000</td>\n",
159
+ " <td>26.000000</td>\n",
160
+ " <td>26.000000</td>\n",
161
+ " <td>26.000000</td>\n",
162
+ " <td>26.000000</td>\n",
163
+ " <td>acc</td>\n",
164
+ " <td>Wake Forest</td>\n",
165
+ " <td>1985</td>\n",
166
+ " <td>2024</td>\n",
167
+ " <td>W02</td>\n",
168
+ " </tr>\n",
169
+ " <tr>\n",
170
+ " <th>2</th>\n",
171
+ " <td>9</td>\n",
172
+ " <td>2003</td>\n",
173
+ " <td>115</td>\n",
174
+ " <td>1266</td>\n",
175
+ " <td>78</td>\n",
176
+ " <td>1257</td>\n",
177
+ " <td>73</td>\n",
178
+ " <td>A</td>\n",
179
+ " <td>0</td>\n",
180
+ " <td>26</td>\n",
181
+ " <td>...</td>\n",
182
+ " <td>24.000000</td>\n",
183
+ " <td>24.000000</td>\n",
184
+ " <td>24.000000</td>\n",
185
+ " <td>24.000000</td>\n",
186
+ " <td>24.000000</td>\n",
187
+ " <td>cusa</td>\n",
188
+ " <td>Louisville</td>\n",
189
+ " <td>1985</td>\n",
190
+ " <td>2024</td>\n",
191
+ " <td>W04</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>3</th>\n",
195
+ " <td>12</td>\n",
196
+ " <td>2003</td>\n",
197
+ " <td>138</td>\n",
198
+ " <td>1266</td>\n",
199
+ " <td>101</td>\n",
200
+ " <td>1281</td>\n",
201
+ " <td>92</td>\n",
202
+ " <td>N</td>\n",
203
+ " <td>1</td>\n",
204
+ " <td>35</td>\n",
205
+ " <td>...</td>\n",
206
+ " <td>26.000000</td>\n",
207
+ " <td>26.000000</td>\n",
208
+ " <td>26.000000</td>\n",
209
+ " <td>26.000000</td>\n",
210
+ " <td>26.000000</td>\n",
211
+ " <td>big_twelve</td>\n",
212
+ " <td>Missouri</td>\n",
213
+ " <td>1985</td>\n",
214
+ " <td>2024</td>\n",
215
+ " <td>Y06</td>\n",
216
+ " </tr>\n",
217
+ " <tr>\n",
218
+ " <th>4</th>\n",
219
+ " <td>19</td>\n",
220
+ " <td>2003</td>\n",
221
+ " <td>143</td>\n",
222
+ " <td>1266</td>\n",
223
+ " <td>77</td>\n",
224
+ " <td>1338</td>\n",
225
+ " <td>74</td>\n",
226
+ " <td>N</td>\n",
227
+ " <td>0</td>\n",
228
+ " <td>28</td>\n",
229
+ " <td>...</td>\n",
230
+ " <td>21.333333</td>\n",
231
+ " <td>21.333333</td>\n",
232
+ " <td>21.333333</td>\n",
233
+ " <td>21.333333</td>\n",
234
+ " <td>21.333333</td>\n",
235
+ " <td>big_east</td>\n",
236
+ " <td>Pittsburgh</td>\n",
237
+ " <td>1985</td>\n",
238
+ " <td>2024</td>\n",
239
+ " <td>Y02</td>\n",
240
+ " </tr>\n",
241
+ " </tbody>\n",
242
+ "</table>\n",
243
+ "<p>5 rows × 1068 columns</p>\n",
244
+ "</div>"
245
+ ],
246
+ "text/plain": [
247
+ " Unnamed: 0 Season DayNum WTeamID WScore LTeamID LScore WLoc NumOT \\\n",
248
+ "0 0 2003 40 1266 63 1458 54 H 0 \n",
249
+ "1 5 2003 97 1266 68 1448 61 H 0 \n",
250
+ "2 9 2003 115 1266 78 1257 73 A 0 \n",
251
+ "3 12 2003 138 1266 101 1281 92 N 1 \n",
252
+ "4 19 2003 143 1266 77 1338 74 N 0 \n",
253
+ "\n",
254
+ " WFGM ... tourney_DR_max tourney_DR_mean tourney_DR_median \\\n",
255
+ "0 24 ... 21.666667 21.666667 21.666667 \n",
256
+ "1 21 ... 26.000000 26.000000 26.000000 \n",
257
+ "2 26 ... 24.000000 24.000000 24.000000 \n",
258
+ "3 35 ... 26.000000 26.000000 26.000000 \n",
259
+ "4 28 ... 21.333333 21.333333 21.333333 \n",
260
+ "\n",
261
+ " tourney_DR_std tourney_DR_sum ConfAbbrev TeamName FirstD1Season \\\n",
262
+ "0 21.666667 21.666667 big_ten Wisconsin 1985 \n",
263
+ "1 26.000000 26.000000 acc Wake Forest 1985 \n",
264
+ "2 24.000000 24.000000 cusa Louisville 1985 \n",
265
+ "3 26.000000 26.000000 big_twelve Missouri 1985 \n",
266
+ "4 21.333333 21.333333 big_east Pittsburgh 1985 \n",
267
+ "\n",
268
+ " LastD1Season Seed \n",
269
+ "0 2024 Y05 \n",
270
+ "1 2024 W02 \n",
271
+ "2 2024 W04 \n",
272
+ "3 2024 Y06 \n",
273
+ "4 2024 Y02 \n",
274
+ "\n",
275
+ "[5 rows x 1068 columns]"
276
+ ]
277
+ },
278
+ "execution_count": 3,
279
+ "metadata": {},
280
+ "output_type": "execute_result"
281
+ }
282
+ ],
283
+ "source": [
284
+ "all_games_df.head()"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "id": "58e4fee8",
290
+ "metadata": {},
291
+ "source": [
292
+ "# Feature Selection"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 4,
298
+ "id": "1251726e",
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "target_df = all_games_df[[\"tourney_Score_mean\", \"tourney_Score_std\", \"tourney_Score_max\", \"tourney_Score_min\"]]\n",
303
+ "\n",
304
+ "features_df = all_games_df[[col for col in all_games_df if col.startswith(\"reg\") and \"_W\" not in col and \"_L\" not in col and \"sum\" not in col]]\n",
305
+ "# features_df = features_df.select_dtypes(include=\"number\")\n",
306
+ "\n",
307
+ "# split data into training and testing data sets\n",
308
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
309
+ " features_df.astype(float),\n",
310
+ " target_df.astype(float),\n",
311
+ " train_size=0.8,\n",
312
+ " random_state=8,\n",
313
+ ")"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": 5,
319
+ "id": "28478189",
320
+ "metadata": {},
321
+ "outputs": [
322
+ {
323
+ "name": "stdout",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "<class 'pandas.core.frame.DataFrame'>\n",
327
+ "Int64Index: 524 entries, 5 to 451\n",
328
+ "Data columns (total 71 columns):\n",
329
+ " # Column Non-Null Count Dtype \n",
330
+ "--- ------ -------------- ----- \n",
331
+ " 0 reg_Games 524 non-null float64\n",
332
+ " 1 reg_Score_min 524 non-null float64\n",
333
+ " 2 reg_Score_max 524 non-null float64\n",
334
+ " 3 reg_Score_mean 524 non-null float64\n",
335
+ " 4 reg_Score_median 524 non-null float64\n",
336
+ " 5 reg_Score_std 524 non-null float64\n",
337
+ " 6 reg_FGM_min 524 non-null float64\n",
338
+ " 7 reg_FGM_max 524 non-null float64\n",
339
+ " 8 reg_FGM_mean 524 non-null float64\n",
340
+ " 9 reg_FGM_median 524 non-null float64\n",
341
+ " 10 reg_FGM_std 524 non-null float64\n",
342
+ " 11 reg_FGA_min 524 non-null float64\n",
343
+ " 12 reg_FGA_max 524 non-null float64\n",
344
+ " 13 reg_FGA_mean 524 non-null float64\n",
345
+ " 14 reg_FGA_median 524 non-null float64\n",
346
+ " 15 reg_FGA_std 524 non-null float64\n",
347
+ " 16 reg_FTM_min 524 non-null float64\n",
348
+ " 17 reg_FTM_max 524 non-null float64\n",
349
+ " 18 reg_FTM_mean 524 non-null float64\n",
350
+ " 19 reg_FTM_median 524 non-null float64\n",
351
+ " 20 reg_FTM_std 524 non-null float64\n",
352
+ " 21 reg_FTA_min 524 non-null float64\n",
353
+ " 22 reg_FTA_max 524 non-null float64\n",
354
+ " 23 reg_FTA_mean 524 non-null float64\n",
355
+ " 24 reg_FTA_median 524 non-null float64\n",
356
+ " 25 reg_FTA_std 524 non-null float64\n",
357
+ " 26 reg_Ast_min 524 non-null float64\n",
358
+ " 27 reg_Ast_max 524 non-null float64\n",
359
+ " 28 reg_Ast_mean 524 non-null float64\n",
360
+ " 29 reg_Ast_median 524 non-null float64\n",
361
+ " 30 reg_Ast_std 524 non-null float64\n",
362
+ " 31 reg_Blk_min 524 non-null float64\n",
363
+ " 32 reg_Blk_max 524 non-null float64\n",
364
+ " 33 reg_Blk_mean 524 non-null float64\n",
365
+ " 34 reg_Blk_median 524 non-null float64\n",
366
+ " 35 reg_Blk_std 524 non-null float64\n",
367
+ " 36 reg_PF_min 524 non-null float64\n",
368
+ " 37 reg_PF_max 524 non-null float64\n",
369
+ " 38 reg_PF_mean 524 non-null float64\n",
370
+ " 39 reg_PF_median 524 non-null float64\n",
371
+ " 40 reg_PF_std 524 non-null float64\n",
372
+ " 41 reg_Stl_min 524 non-null float64\n",
373
+ " 42 reg_Stl_max 524 non-null float64\n",
374
+ " 43 reg_Stl_mean 524 non-null float64\n",
375
+ " 44 reg_Stl_median 524 non-null float64\n",
376
+ " 45 reg_Stl_std 524 non-null float64\n",
377
+ " 46 reg_TO_min 524 non-null float64\n",
378
+ " 47 reg_TO_max 524 non-null float64\n",
379
+ " 48 reg_TO_mean 524 non-null float64\n",
380
+ " 49 reg_TO_median 524 non-null float64\n",
381
+ " 50 reg_TO_std 524 non-null float64\n",
382
+ " 51 reg_FGM3_min 524 non-null float64\n",
383
+ " 52 reg_FGM3_max 524 non-null float64\n",
384
+ " 53 reg_FGM3_mean 524 non-null float64\n",
385
+ " 54 reg_FGM3_median 524 non-null float64\n",
386
+ " 55 reg_FGM3_std 524 non-null float64\n",
387
+ " 56 reg_FGA3_min 524 non-null float64\n",
388
+ " 57 reg_FGA3_max 524 non-null float64\n",
389
+ " 58 reg_FGA3_mean 524 non-null float64\n",
390
+ " 59 reg_FGA3_median 524 non-null float64\n",
391
+ " 60 reg_FGA3_std 524 non-null float64\n",
392
+ " 61 reg_OR_min 524 non-null float64\n",
393
+ " 62 reg_OR_max 524 non-null float64\n",
394
+ " 63 reg_OR_mean 524 non-null float64\n",
395
+ " 64 reg_OR_median 524 non-null float64\n",
396
+ " 65 reg_OR_std 524 non-null float64\n",
397
+ " 66 reg_DR_min 524 non-null float64\n",
398
+ " 67 reg_DR_max 524 non-null float64\n",
399
+ " 68 reg_DR_mean 524 non-null float64\n",
400
+ " 69 reg_DR_median 524 non-null float64\n",
401
+ " 70 reg_DR_std 524 non-null float64\n",
402
+ "dtypes: float64(71)\n",
403
+ "memory usage: 294.8 KB\n"
404
+ ]
405
+ }
406
+ ],
407
+ "source": [
408
+ "X_train.info()"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 6,
414
+ "id": "04f4a0a6",
415
+ "metadata": {},
416
+ "outputs": [
417
+ {
418
+ "name": "stdout",
419
+ "output_type": "stream",
420
+ "text": [
421
+ "<class 'pandas.core.frame.DataFrame'>\n",
422
+ "Int64Index: 524 entries, 5 to 451\n",
423
+ "Data columns (total 4 columns):\n",
424
+ " # Column Non-Null Count Dtype \n",
425
+ "--- ------ -------------- ----- \n",
426
+ " 0 tourney_Score_mean 524 non-null float64\n",
427
+ " 1 tourney_Score_std 524 non-null float64\n",
428
+ " 2 tourney_Score_max 524 non-null float64\n",
429
+ " 3 tourney_Score_min 524 non-null float64\n",
430
+ "dtypes: float64(4)\n",
431
+ "memory usage: 20.5 KB\n"
432
+ ]
433
+ }
434
+ ],
435
+ "source": [
436
+ "y_train.info()"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 7,
442
+ "id": "40094cd0",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "# turn every split into a float32 tensor and move it onto DEVICE\n",
447
+ "# (cuda, mps or cpu)\n",
448
+ "X_trainT = torch.tensor(\n",
449
+ "    X_train.values, dtype=torch.float32,\n",
450
+ ").to(DEVICE)\n",
451
+ "\n",
452
+ "X_testT = torch.tensor(\n",
453
+ "    X_test.values, dtype=torch.float32,\n",
454
+ ").to(DEVICE)\n",
455
+ "\n",
456
+ "y_trainT = torch.tensor(\n",
457
+ "    y_train.values, dtype=torch.float32,\n",
458
+ ").to(DEVICE)\n",
459
+ "\n",
460
+ "y_testT = torch.tensor(\n",
461
+ "    y_test.values, dtype=torch.float32,\n",
462
+ ").to(DEVICE)"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "markdown",
467
+ "id": "20bceb9a",
468
+ "metadata": {},
469
+ "source": [
470
+ "# Building Neural Network"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": 8,
476
+ "id": "7b0573ee",
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "num_features = X_train.shape[1]  # one input unit per feature column\n",
481
+ "\n",
482
+ "class MadnessNN(nn.Module):\n",
483
+ "    \"\"\"Five stacked linear layers (num_features -> 64 -> 32 -> 16 -> 8 -> 4), each followed by ReLU.\"\"\"\n",
484
+ "    def __init__(self) -> None:\n",
485
+ "        super().__init__()\n",
486
+ "        self.input_layer = nn.Linear(num_features, 64)  # num_features -> 64\n",
487
+ "        self.activation_func = nn.ReLU()  # shared by every layer in forward()\n",
488
+ "        self.layer1 = nn.Linear(64, 32)\n",
489
+ "        self.layer2 = nn.Linear(32, 16)\n",
490
+ "        self.layer3 = nn.Linear(16, 8)\n",
491
+ "        self.output_layer = nn.Linear(8, 4)  # 4 output values per row\n",
492
+ "\n",
493
+ "    def forward(self, x):\n",
494
+ "        stages = (\n",
495
+ "            self.input_layer,\n",
496
+ "            self.layer1,\n",
497
+ "            self.layer2,\n",
498
+ "            self.layer3,\n",
499
+ "            self.output_layer,\n",
500
+ "        )\n",
501
+ "        for stage in stages:  # NOTE(review): ReLU is applied to the final output too - confirm intended\n",
502
+ "            x = self.activation_func(stage(x))\n",
503
+ "        return x\n"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "id": "061e2b52",
509
+ "metadata": {},
510
+ "source": [
511
+ "# Training Loop"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": 21,
517
+ "id": "db035b9d",
518
+ "metadata": {},
519
+ "outputs": [
520
+ {
521
+ "name": "stdout",
522
+ "output_type": "stream",
523
+ "text": [
524
+ "[500 / 5000] Loss = 40.454681396484375\n",
525
+ "[1000 / 5000] Loss = 39.701454162597656\n",
526
+ "[1500 / 5000] Loss = 39.055484771728516\n",
527
+ "[2000 / 5000] Loss = 38.53948974609375\n",
528
+ "[2500 / 5000] Loss = 38.149085998535156\n",
529
+ "[3000 / 5000] Loss = 37.87413024902344\n",
530
+ "[3500 / 5000] Loss = 37.6934928894043\n",
531
+ "[4000 / 5000] Loss = 37.573673248291016\n",
532
+ "[4500 / 5000] Loss = 37.48927307128906\n",
533
+ "[5000 / 5000] Loss = 37.43183135986328\n"
534
+ ]
535
+ }
536
+ ],
537
+ "source": [
538
+ "torch.manual_seed(1)  # fixed seed so the weight initialisation is repeatable\n",
539
+ "\n",
540
+ "model5000 = MadnessNN().to(DEVICE)  # model must sit on the same device as X_trainT / y_trainT\n",
541
+ "optimizer = optim.Adam(lr=0.001, params=model5000.parameters())\n",
542
+ "loss_fn = nn.MSELoss()\n",
543
+ "epochs = 5000\n",
544
+ "# full-batch training: each epoch is a single Adam step on the whole training tensor\n",
545
+ "for epoch in range(1, epochs + 1):\n",
546
+ "    pred = model5000(X_trainT)\n",
547
+ "    loss = loss_fn(pred, y_trainT)\n",
548
+ "    loss.backward()\n",
549
+ "    optimizer.step()\n",
550
+ "    optimizer.zero_grad()  # clear gradients so they do not accumulate into the next epoch\n",
551
+ "    # report the training loss every 500 epochs\n",
552
+ "    if epoch % 500 == 0:\n",
553
+ "        print(f\"[{epoch} / {epochs}] Loss = {loss}\")\n"
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "execution_count": 22,
559
+ "id": "b62fd19c",
560
+ "metadata": {},
561
+ "outputs": [],
562
+ "source": [
563
+ "os.makedirs(\"models\", exist_ok=True)  # save: torch.save does not create a missing target directory\n",
564
+ "torch.save(model5000, os.path.join(\"models\", \"model5000.pth\"))  # stores the whole module object, not only its state_dict"
565
+ ]
566
+ },
567
+ {
568
+ "cell_type": "code",
569
+ "execution_count": 23,
570
+ "id": "17694dc7",
571
+ "metadata": {},
572
+ "outputs": [
573
+ {
574
+ "name": "stdout",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "MSE on testing data: 47.071144104003906\n"
578
+ ]
579
+ }
580
+ ],
581
+ "source": [
582
+ "# evaluate: switch the model to eval mode and measure the loss on the held-out test tensors\n",
583
+ "model5000.eval()\n",
584
+ "\n",
585
+ "with torch.no_grad():\n",
586
+ " pred = model5000(X_testT)\n",
587
+ " loss = loss_fn(pred, y_testT)\n",
588
+ " print(f\"MSE on testing data: {loss}\")\n"
589
+ ]
590
+ }
591
+ ],
592
+ "metadata": {
593
+ "kernelspec": {
594
+ "display_name": "Python 3",
595
+ "language": "python",
596
+ "name": "python3"
597
+ },
598
+ "language_info": {
599
+ "codemirror_mode": {
600
+ "name": "ipython",
601
+ "version": 3
602
+ },
603
+ "file_extension": ".py",
604
+ "mimetype": "text/x-python",
605
+ "name": "python",
606
+ "nbconvert_exporter": "python",
607
+ "pygments_lexer": "ipython3",
608
+ "version": "3.11.7"
609
+ }
610
+ },
611
+ "nbformat": 4,
612
+ "nbformat_minor": 5
613
+ }
src/mens_pre_processing.ipynb ADDED
The diff for this file is too large to render. See raw diff