Jensen-holm
commited on
Commit
·
ef83bf7
1
Parent(s):
6d466dc
re trained the models on only tournament games and the ChalkSeedDiff
Browse filesadded as a feature. This really helped the womens neural network, but
the mens one is maybe a little bit worse
- data/AllSuperDetailedGames.csv +2 -2
- data/AllTeamsAgg.csv +1 -1
- models/Mnn10k.pth +2 -2
- models/Wnn10k.pth +2 -2
- src/__pycache__/visual_eval.cpython-311.pyc +0 -0
- src/baseline.ipynb +3 -3
- src/nn.ipynb +0 -0
- src/pre_processing.ipynb +258 -201
- src/visual_eval.py +33 -0
- src/visualizations.py +0 -27
data/AllSuperDetailedGames.csv
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ec95fd20de671096891e8426969303b879e0720e52e6da3dc55a0369ba98787
|
3 |
+
size 1046244854
|
data/AllTeamsAgg.csv
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 31040659
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdbaeb6c905ad5f6480ffe9acb3ebb75ffe8905954826574410c0d8c94a12826
|
3 |
size 31040659
|
models/Mnn10k.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:500d7ddd0596b59cf7fbfab9abc7d8f50278f8269f4399e9cb5222fe843418a7
|
3 |
+
size 19170
|
models/Wnn10k.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1de0daa46afa742af4a7d31a026b54917abfca4c972546208d50b827e2fe4119
|
3 |
+
size 19170
|
src/__pycache__/visual_eval.cpython-311.pyc
ADDED
Binary file (2.7 kB). View file
|
|
src/baseline.ipynb
CHANGED
@@ -68,7 +68,7 @@
|
|
68 |
"# games_df[\"BaselinePrediction\"] = games_df.apply(\n",
|
69 |
"# lambda row: predict_baseline(row),\n",
|
70 |
"# axis=1,\n",
|
71 |
-
"# )
|
72 |
]
|
73 |
},
|
74 |
{
|
@@ -461,7 +461,7 @@
|
|
461 |
"wmns_actual_T = torch.tensor(\n",
|
462 |
" wmns_subset[\"Win\"].values,\n",
|
463 |
" dtype=torch.float32,\n",
|
464 |
-
")
|
465 |
]
|
466 |
},
|
467 |
{
|
@@ -513,7 +513,7 @@
|
|
513 |
"plt.ylabel(\"True Positive Rate\")\n",
|
514 |
"\n",
|
515 |
"plt.tight_layout()\n",
|
516 |
-
"plt.show()
|
517 |
]
|
518 |
},
|
519 |
{
|
|
|
68 |
"# games_df[\"BaselinePrediction\"] = games_df.apply(\n",
|
69 |
"# lambda row: predict_baseline(row),\n",
|
70 |
"# axis=1,\n",
|
71 |
+
"# )"
|
72 |
]
|
73 |
},
|
74 |
{
|
|
|
461 |
"wmns_actual_T = torch.tensor(\n",
|
462 |
" wmns_subset[\"Win\"].values,\n",
|
463 |
" dtype=torch.float32,\n",
|
464 |
+
")"
|
465 |
]
|
466 |
},
|
467 |
{
|
|
|
513 |
"plt.ylabel(\"True Positive Rate\")\n",
|
514 |
"\n",
|
515 |
"plt.tight_layout()\n",
|
516 |
+
"plt.show()"
|
517 |
]
|
518 |
},
|
519 |
{
|
src/nn.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
src/pre_processing.ipynb
CHANGED
@@ -10,7 +10,7 @@
|
|
10 |
"import numpy as np\n",
|
11 |
"import os\n",
|
12 |
"\n",
|
13 |
-
"DATA_DIR = os.path.join(\"..\", \"data\")
|
14 |
]
|
15 |
},
|
16 |
{
|
@@ -212,10 +212,16 @@
|
|
212 |
}
|
213 |
],
|
214 |
"source": [
|
215 |
-
"detailed_tourney_games_df = pd.concat(
|
216 |
-
"
|
217 |
-
"
|
218 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
"\n",
|
220 |
"detailed_tourney_games_df.sample(5, random_state=1)"
|
221 |
]
|
@@ -419,10 +425,16 @@
|
|
419 |
}
|
420 |
],
|
421 |
"source": [
|
422 |
-
"detailed_reg_games_df = pd.concat(
|
423 |
-
"
|
424 |
-
"
|
425 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
"\n",
|
427 |
"detailed_reg_games_df.sample(5, random_state=1)"
|
428 |
]
|
@@ -445,7 +457,7 @@
|
|
445 |
"\n",
|
446 |
"detailed_metrics = {\n",
|
447 |
" \"Score\",\n",
|
448 |
-
" # \"Loc\"
|
449 |
" \"FGM\",\n",
|
450 |
" \"FGA\",\n",
|
451 |
" \"FGM3\",\n",
|
@@ -460,8 +472,12 @@
|
|
460 |
" \"PF\",\n",
|
461 |
"}\n",
|
462 |
"\n",
|
463 |
-
"w_renamed_cols = {f\"W{col}\": f\"Team{col}\" for col in detailed_metrics} | {
|
464 |
-
"
|
|
|
|
|
|
|
|
|
465 |
]
|
466 |
},
|
467 |
{
|
@@ -520,22 +536,26 @@
|
|
520 |
}
|
521 |
],
|
522 |
"source": [
|
523 |
-
"\n",
|
524 |
-
"
|
525 |
-
"
|
526 |
-
"
|
527 |
-
"
|
528 |
-
"
|
529 |
-
"
|
530 |
-
"
|
531 |
-
"
|
532 |
-
"
|
533 |
-
"
|
534 |
-
"
|
535 |
-
"
|
536 |
-
"
|
537 |
-
"\n",
|
538 |
-
"
|
|
|
|
|
|
|
|
|
539 |
"\n",
|
540 |
"detailed_reg_games_df.info()"
|
541 |
]
|
@@ -597,20 +617,30 @@
|
|
597 |
],
|
598 |
"source": [
|
599 |
"# do the same thing for the tournament games\n",
|
600 |
-
"detailed_tourney_games_df = pd.concat(
|
601 |
-
"
|
602 |
-
"
|
603 |
-
"
|
604 |
-
"
|
605 |
-
"
|
606 |
-
"
|
607 |
-
"
|
608 |
-
"
|
609 |
-
"
|
610 |
-
"
|
611 |
-
"
|
612 |
-
"
|
613 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
"\n",
|
615 |
"detailed_tourney_games_df.info()"
|
616 |
]
|
@@ -621,7 +651,6 @@
|
|
621 |
"metadata": {},
|
622 |
"outputs": [],
|
623 |
"source": [
|
624 |
-
"\n",
|
625 |
"for col in detailed_metrics:\n",
|
626 |
" detailed_reg_games_df[f\"{col}Diff\"] = detailed_reg_games_df.apply(\n",
|
627 |
" lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
|
@@ -631,7 +660,7 @@
|
|
631 |
" detailed_tourney_games_df[f\"{col}Diff\"] = detailed_tourney_games_df.apply(\n",
|
632 |
" lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
|
633 |
" axis=1,\n",
|
634 |
-
" )
|
635 |
]
|
636 |
},
|
637 |
{
|
@@ -671,16 +700,16 @@
|
|
671 |
" <th>TeamFGM</th>\n",
|
672 |
" <th>TeamFGA</th>\n",
|
673 |
" <th>...</th>\n",
|
674 |
-
" <th>PFDiff</th>\n",
|
675 |
-
" <th>TODiff</th>\n",
|
676 |
-
" <th>ORDiff</th>\n",
|
677 |
-
" <th>FGMDiff</th>\n",
|
678 |
-
" <th>BlkDiff</th>\n",
|
679 |
" <th>FTADiff</th>\n",
|
680 |
-
" <th>
|
681 |
-
" <th>FGM3Diff</th>\n",
|
682 |
" <th>ScoreDiff</th>\n",
|
683 |
" <th>FGADiff</th>\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
684 |
" </tr>\n",
|
685 |
" </thead>\n",
|
686 |
" <tbody>\n",
|
@@ -697,16 +726,16 @@
|
|
697 |
" <td>21</td>\n",
|
698 |
" <td>55</td>\n",
|
699 |
" <td>...</td>\n",
|
700 |
-
" <td>9</td>\n",
|
701 |
-
" <td>7</td>\n",
|
702 |
" <td>-11</td>\n",
|
703 |
-
" <td
|
|
|
|
|
704 |
" <td>1</td>\n",
|
|
|
705 |
" <td>-11</td>\n",
|
706 |
" <td>-7</td>\n",
|
707 |
-
" <td>-
|
708 |
-
" <td>-
|
709 |
-
" <td>-12</td>\n",
|
710 |
" </tr>\n",
|
711 |
" <tr>\n",
|
712 |
" <th>100732</th>\n",
|
@@ -721,16 +750,16 @@
|
|
721 |
" <td>23</td>\n",
|
722 |
" <td>60</td>\n",
|
723 |
" <td>...</td>\n",
|
724 |
-
" <td>-9</td>\n",
|
725 |
-
" <td>-6</td>\n",
|
726 |
-
" <td>-1</td>\n",
|
727 |
-
" <td>-1</td>\n",
|
728 |
-
" <td>2</td>\n",
|
729 |
" <td>17</td>\n",
|
730 |
-
" <td
|
731 |
-
" <td>-2</td>\n",
|
732 |
" <td>12</td>\n",
|
733 |
" <td>-4</td>\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
734 |
" </tr>\n",
|
735 |
" <tr>\n",
|
736 |
" <th>83150</th>\n",
|
@@ -745,16 +774,16 @@
|
|
745 |
" <td>27</td>\n",
|
746 |
" <td>58</td>\n",
|
747 |
" <td>...</td>\n",
|
|
|
748 |
" <td>-5</td>\n",
|
|
|
|
|
|
|
749 |
" <td>1</td>\n",
|
750 |
" <td>4</td>\n",
|
751 |
-
" <td>-1</td>\n",
|
752 |
-
" <td>2</td>\n",
|
753 |
-
" <td>10</td>\n",
|
754 |
" <td>-5</td>\n",
|
755 |
" <td>1</td>\n",
|
756 |
" <td>13</td>\n",
|
757 |
-
" <td>-6</td>\n",
|
758 |
" </tr>\n",
|
759 |
" <tr>\n",
|
760 |
" <th>345009</th>\n",
|
@@ -769,16 +798,16 @@
|
|
769 |
" <td>19</td>\n",
|
770 |
" <td>55</td>\n",
|
771 |
" <td>...</td>\n",
|
772 |
-
" <td>7</td>\n",
|
773 |
-
" <td>-5</td>\n",
|
774 |
-
" <td>2</td>\n",
|
775 |
-
" <td>1</td>\n",
|
776 |
-
" <td>-3</td>\n",
|
777 |
" <td>-11</td>\n",
|
|
|
|
|
|
|
778 |
" <td>-3</td>\n",
|
779 |
" <td>-1</td>\n",
|
|
|
|
|
|
|
780 |
" <td>-7</td>\n",
|
781 |
-
" <td>13</td>\n",
|
782 |
" </tr>\n",
|
783 |
" <tr>\n",
|
784 |
" <th>318707</th>\n",
|
@@ -793,16 +822,16 @@
|
|
793 |
" <td>20</td>\n",
|
794 |
" <td>51</td>\n",
|
795 |
" <td>...</td>\n",
|
|
|
796 |
" <td>2</td>\n",
|
797 |
-
" <td
|
798 |
-
" <td>
|
799 |
-
" <td>-3</td>\n",
|
800 |
" <td>1</td>\n",
|
801 |
-
" <td>-
|
|
|
802 |
" <td>-7</td>\n",
|
|
|
803 |
" <td>-3</td>\n",
|
804 |
-
" <td>-18</td>\n",
|
805 |
-
" <td>3</td>\n",
|
806 |
" </tr>\n",
|
807 |
" </tbody>\n",
|
808 |
"</table>\n",
|
@@ -817,19 +846,19 @@
|
|
817 |
"345009 2019 4 3435 58 3292 65 H 0 \n",
|
818 |
"318707 2013 128 3322 45 3270 63 N 0 \n",
|
819 |
"\n",
|
820 |
-
" TeamFGM TeamFGA ... PFDiff
|
821 |
-
"337067 21 55 ...
|
822 |
-
"100732 23 60 ... -9
|
823 |
-
"83150 27 58 ... -5
|
824 |
-
"345009 19 55 ... 7
|
825 |
-
"318707 20 51 ... 2
|
826 |
"\n",
|
827 |
-
"
|
828 |
-
"337067
|
829 |
-
"100732
|
830 |
-
"83150
|
831 |
-
"345009
|
832 |
-
"318707
|
833 |
"\n",
|
834 |
"[5 rows x 49 columns]"
|
835 |
]
|
@@ -878,10 +907,12 @@
|
|
878 |
"source": [
|
879 |
"# combine the two detailed game dataframes into one for future use\n",
|
880 |
"\n",
|
881 |
-
"all_detailed_games_df = pd.concat(
|
882 |
-
"
|
883 |
-
"
|
884 |
-
"
|
|
|
|
|
885 |
]
|
886 |
},
|
887 |
{
|
@@ -1306,7 +1337,13 @@
|
|
1306 |
"source": [
|
1307 |
"team_reg_agg = (\n",
|
1308 |
" detailed_reg_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
|
1309 |
-
" .agg(
|
|
|
|
|
|
|
|
|
|
|
|
|
1310 |
" .reset_index()\n",
|
1311 |
")\n",
|
1312 |
"\n",
|
@@ -1668,15 +1705,23 @@
|
|
1668 |
}
|
1669 |
],
|
1670 |
"source": [
|
1671 |
-
"# aggregate the same metrics for the tournament dataset
|
1672 |
"\n",
|
1673 |
"team_tourney_agg = (\n",
|
1674 |
" detailed_tourney_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
|
1675 |
-
" .agg(
|
|
|
|
|
|
|
|
|
|
|
|
|
1676 |
" .reset_index()\n",
|
1677 |
")\n",
|
1678 |
"\n",
|
1679 |
-
"team_tourney_agg.columns = [\
|
|
|
|
|
1680 |
"\n",
|
1681 |
"team_tourney_agg.sample(10, random_state=1)"
|
1682 |
]
|
@@ -1727,83 +1772,83 @@
|
|
1727 |
" </thead>\n",
|
1728 |
" <tbody>\n",
|
1729 |
" <tr>\n",
|
1730 |
-
" <th>
|
1731 |
-
" <td>
|
1732 |
-
" <td>
|
1733 |
-
" <td>
|
1734 |
-
" <td>
|
1735 |
-
" <td>
|
1736 |
-
" <td>
|
1737 |
-
" <td>
|
1738 |
-
" <td>
|
1739 |
-
" <td>
|
1740 |
" </tr>\n",
|
1741 |
" <tr>\n",
|
1742 |
-
" <th>
|
1743 |
-
" <td>
|
1744 |
-
" <td>
|
1745 |
-
" <td>
|
1746 |
-
" <td>
|
1747 |
-
" <td>
|
1748 |
-
" <td>
|
1749 |
-
" <td>
|
1750 |
-
" <td>
|
1751 |
-
" <td>
|
1752 |
" </tr>\n",
|
1753 |
" <tr>\n",
|
1754 |
-
" <th>
|
1755 |
-
" <td>
|
1756 |
-
" <td>
|
1757 |
-
" <td>
|
1758 |
" <td>M</td>\n",
|
1759 |
-
" <td>
|
1760 |
-
" <td>
|
1761 |
-
" <td>1985</td>\n",
|
1762 |
-
" <td>2024</td>\n",
|
1763 |
-
" <td>
|
1764 |
" </tr>\n",
|
1765 |
" <tr>\n",
|
1766 |
-
" <th>
|
1767 |
-
" <td>
|
1768 |
-
" <td>
|
1769 |
-
" <td>
|
1770 |
" <td>M</td>\n",
|
1771 |
-
" <td>
|
1772 |
-
" <td>
|
1773 |
-
" <td>1985</td>\n",
|
1774 |
-
" <td>2024</td>\n",
|
1775 |
-
" <td>
|
1776 |
" </tr>\n",
|
1777 |
" <tr>\n",
|
1778 |
-
" <th>
|
1779 |
-
" <td>
|
1780 |
-
" <td>
|
1781 |
-
" <td>
|
1782 |
-
" <td>
|
1783 |
" <td>big_east</td>\n",
|
1784 |
-
" <td>
|
1785 |
-
" <td>
|
1786 |
-
" <td>
|
1787 |
-
" <td>
|
1788 |
" </tr>\n",
|
1789 |
" </tbody>\n",
|
1790 |
"</table>\n",
|
1791 |
"</div>"
|
1792 |
],
|
1793 |
"text/plain": [
|
1794 |
-
"
|
1795 |
-
"
|
1796 |
-
"
|
1797 |
-
"
|
1798 |
-
"
|
1799 |
-
"
|
1800 |
"\n",
|
1801 |
-
"
|
1802 |
-
"
|
1803 |
-
"1
|
1804 |
-
"
|
1805 |
-
"
|
1806 |
-
"
|
1807 |
]
|
1808 |
},
|
1809 |
"execution_count": 15,
|
@@ -1812,26 +1857,34 @@
|
|
1812 |
}
|
1813 |
],
|
1814 |
"source": [
|
1815 |
-
"conference_df = pd.concat(
|
1816 |
-
"
|
1817 |
-
"
|
1818 |
-
"\n",
|
1819 |
-
"
|
1820 |
-
"
|
1821 |
-
"])\n",
|
1822 |
"\n",
|
1823 |
-
"team_conf_seeds_df = (\n",
|
1824 |
-
"
|
1825 |
-
"
|
1826 |
-
"
|
1827 |
-
"
|
1828 |
-
"
|
1829 |
-
"
|
1830 |
-
"
|
1831 |
-
"
|
1832 |
-
"
|
1833 |
-
"
|
1834 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1835 |
")\n",
|
1836 |
"\n",
|
1837 |
"team_conf_seeds_df[\"ChalkSeed\"] = team_conf_seeds_df.apply(\n",
|
@@ -1839,7 +1892,7 @@
|
|
1839 |
" axis=1,\n",
|
1840 |
")\n",
|
1841 |
"\n",
|
1842 |
-
"team_conf_seeds_df.
|
1843 |
]
|
1844 |
},
|
1845 |
{
|
@@ -2221,10 +2274,10 @@
|
|
2221 |
"source": [
|
2222 |
"# merge the tournament aggregated metrics with the regular season aggregated metrics\n",
|
2223 |
"team_agg_df = pd.merge(\n",
|
2224 |
-
" left=team_reg_agg
|
2225 |
-
" right=team_tourney_agg
|
2226 |
" how=\"left\",\n",
|
2227 |
-
" on=[\"TeamID\", \"Season\", \"League\"]
|
2228 |
" suffixes=(\" reg\", \" tourney\"),\n",
|
2229 |
" validate=\"1:1\",\n",
|
2230 |
")\n",
|
@@ -2260,10 +2313,10 @@
|
|
2260 |
"output_type": "stream",
|
2261 |
"text": [
|
2262 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2263 |
-
"Int64Index:
|
2264 |
"Columns: 459 entries, TeamID to ChalkSeed\n",
|
2265 |
-
"dtypes: float64(
|
2266 |
-
"memory usage:
|
2267 |
]
|
2268 |
}
|
2269 |
],
|
@@ -2283,7 +2336,7 @@
|
|
2283 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2284 |
"Int64Index: 377608 entries, 0 to 377607\n",
|
2285 |
"Columns: 508 entries, Season to ChalkSeed\n",
|
2286 |
-
"dtypes: float64(
|
2287 |
"memory usage: 1.4+ GB\n"
|
2288 |
]
|
2289 |
}
|
@@ -2344,7 +2397,7 @@
|
|
2344 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2345 |
"Int64Index: 377608 entries, 0 to 377607\n",
|
2346 |
"Columns: 509 entries, Season to OppChalkSeed\n",
|
2347 |
-
"dtypes: float64(
|
2348 |
"memory usage: 1.4+ GB\n"
|
2349 |
]
|
2350 |
}
|
@@ -2352,7 +2405,9 @@
|
|
2352 |
"source": [
|
2353 |
"opp_chalk_seed_map = team_conf_seeds_df.groupby(\"TeamID\")[\"ChalkSeed\"].last()\n",
|
2354 |
"\n",
|
2355 |
-
"super_detailed_games_df[\"OppChalkSeed\"] = super_detailed_games_df[\"OppTeamID\"].map(
|
|
|
|
|
2356 |
"\n",
|
2357 |
"super_detailed_games_df.info()"
|
2358 |
]
|
@@ -2365,18 +2420,18 @@
|
|
2365 |
{
|
2366 |
"data": {
|
2367 |
"text/plain": [
|
2368 |
-
"0
|
2369 |
-
"1
|
2370 |
-
"2
|
2371 |
-
"3
|
2372 |
-
"4
|
2373 |
-
"
|
2374 |
-
"377603
|
2375 |
-
"377604
|
2376 |
-
"377605
|
2377 |
-
"377606
|
2378 |
-
"377607
|
2379 |
-
"Name:
|
2380 |
]
|
2381 |
},
|
2382 |
"execution_count": 22,
|
@@ -2385,7 +2440,9 @@
|
|
2385 |
}
|
2386 |
],
|
2387 |
"source": [
|
2388 |
-
"super_detailed_games_df[\"
|
|
|
|
|
2389 |
]
|
2390 |
},
|
2391 |
{
|
|
|
10 |
"import numpy as np\n",
|
11 |
"import os\n",
|
12 |
"\n",
|
13 |
+
"DATA_DIR = os.path.join(\"..\", \"data\")"
|
14 |
]
|
15 |
},
|
16 |
{
|
|
|
212 |
}
|
213 |
],
|
214 |
"source": [
|
215 |
+
"detailed_tourney_games_df = pd.concat(\n",
|
216 |
+
" [\n",
|
217 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneyDetailedResults.csv\")).assign(\n",
|
218 |
+
" League=\"M\"\n",
|
219 |
+
" ),\n",
|
220 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneyDetailedResults.csv\")).assign(\n",
|
221 |
+
" League=\"W\"\n",
|
222 |
+
" ),\n",
|
223 |
+
" ]\n",
|
224 |
+
")\n",
|
225 |
"\n",
|
226 |
"detailed_tourney_games_df.sample(5, random_state=1)"
|
227 |
]
|
|
|
425 |
}
|
426 |
],
|
427 |
"source": [
|
428 |
+
"detailed_reg_games_df = pd.concat(\n",
|
429 |
+
" [\n",
|
430 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"MRegularSeasonDetailedResults.csv\")).assign(\n",
|
431 |
+
" League=\"M\"\n",
|
432 |
+
" ),\n",
|
433 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"WRegularSeasonDetailedResults.csv\")).assign(\n",
|
434 |
+
" League=\"W\"\n",
|
435 |
+
" ),\n",
|
436 |
+
" ]\n",
|
437 |
+
")\n",
|
438 |
"\n",
|
439 |
"detailed_reg_games_df.sample(5, random_state=1)"
|
440 |
]
|
|
|
457 |
"\n",
|
458 |
"detailed_metrics = {\n",
|
459 |
" \"Score\",\n",
|
460 |
+
" # \"Loc\",\n",
|
461 |
" \"FGM\",\n",
|
462 |
" \"FGA\",\n",
|
463 |
" \"FGM3\",\n",
|
|
|
472 |
" \"PF\",\n",
|
473 |
"}\n",
|
474 |
"\n",
|
475 |
+
"w_renamed_cols = {f\"W{col}\": f\"Team{col}\" for col in detailed_metrics} | {\n",
|
476 |
+
" f\"L{col}\": f\"Opp{col}\" for col in detailed_metrics\n",
|
477 |
+
"}\n",
|
478 |
+
"l_renamed_cols = {f\"L{col}\": f\"Team{col}\" for col in detailed_metrics} | {\n",
|
479 |
+
" f\"W{col}\": f\"Opp{col}\" for col in detailed_metrics\n",
|
480 |
+
"}"
|
481 |
]
|
482 |
},
|
483 |
{
|
|
|
536 |
}
|
537 |
],
|
538 |
"source": [
|
539 |
+
"detailed_reg_games_df = pd.concat(\n",
|
540 |
+
" [\n",
|
541 |
+
" (\n",
|
542 |
+
" # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"LTeamID\"]]\n",
|
543 |
+
" detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
|
544 |
+
" .assign(GameResult=\"W\")\n",
|
545 |
+
" .rename(\n",
|
546 |
+
" columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"}\n",
|
547 |
+
" )\n",
|
548 |
+
" ),\n",
|
549 |
+
" (\n",
|
550 |
+
" # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"WTeamID\"]]\n",
|
551 |
+
" detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
|
552 |
+
" .assign(GameResult=\"L\")\n",
|
553 |
+
" .rename(\n",
|
554 |
+
" columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"}\n",
|
555 |
+
" )\n",
|
556 |
+
" ),\n",
|
557 |
+
" ]\n",
|
558 |
+
").reset_index(drop=True)\n",
|
559 |
"\n",
|
560 |
"detailed_reg_games_df.info()"
|
561 |
]
|
|
|
617 |
],
|
618 |
"source": [
|
619 |
"# do the same thing for the tournament games\n",
|
620 |
+
"detailed_tourney_games_df = pd.concat(\n",
|
621 |
+
" [\n",
|
622 |
+
" (\n",
|
623 |
+
" # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"LTeamID\"]]\n",
|
624 |
+
" detailed_tourney_games_df[\n",
|
625 |
+
" [col for col in detailed_tourney_games_df.columns]\n",
|
626 |
+
" ]\n",
|
627 |
+
" .assign(GameResult=\"W\")\n",
|
628 |
+
" .rename(\n",
|
629 |
+
" columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"}\n",
|
630 |
+
" )\n",
|
631 |
+
" ),\n",
|
632 |
+
" (\n",
|
633 |
+
" # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"WTeamID\"]]\n",
|
634 |
+
" detailed_tourney_games_df[\n",
|
635 |
+
" [col for col in detailed_tourney_games_df.columns]\n",
|
636 |
+
" ]\n",
|
637 |
+
" .assign(GameResult=\"L\")\n",
|
638 |
+
" .rename(\n",
|
639 |
+
" columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"}\n",
|
640 |
+
" )\n",
|
641 |
+
" ),\n",
|
642 |
+
" ]\n",
|
643 |
+
").reset_index(drop=True)\n",
|
644 |
"\n",
|
645 |
"detailed_tourney_games_df.info()"
|
646 |
]
|
|
|
651 |
"metadata": {},
|
652 |
"outputs": [],
|
653 |
"source": [
|
|
|
654 |
"for col in detailed_metrics:\n",
|
655 |
" detailed_reg_games_df[f\"{col}Diff\"] = detailed_reg_games_df.apply(\n",
|
656 |
" lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
|
|
|
660 |
" detailed_tourney_games_df[f\"{col}Diff\"] = detailed_tourney_games_df.apply(\n",
|
661 |
" lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
|
662 |
" axis=1,\n",
|
663 |
+
" )"
|
664 |
]
|
665 |
},
|
666 |
{
|
|
|
700 |
" <th>TeamFGM</th>\n",
|
701 |
" <th>TeamFGA</th>\n",
|
702 |
" <th>...</th>\n",
|
|
|
|
|
|
|
|
|
|
|
703 |
" <th>FTADiff</th>\n",
|
704 |
+
" <th>PFDiff</th>\n",
|
|
|
705 |
" <th>ScoreDiff</th>\n",
|
706 |
" <th>FGADiff</th>\n",
|
707 |
+
" <th>BlkDiff</th>\n",
|
708 |
+
" <th>FGM3Diff</th>\n",
|
709 |
+
" <th>ORDiff</th>\n",
|
710 |
+
" <th>StlDiff</th>\n",
|
711 |
+
" <th>AstDiff</th>\n",
|
712 |
+
" <th>DRDiff</th>\n",
|
713 |
" </tr>\n",
|
714 |
" </thead>\n",
|
715 |
" <tbody>\n",
|
|
|
726 |
" <td>21</td>\n",
|
727 |
" <td>55</td>\n",
|
728 |
" <td>...</td>\n",
|
|
|
|
|
729 |
" <td>-11</td>\n",
|
730 |
+
" <td>9</td>\n",
|
731 |
+
" <td>-28</td>\n",
|
732 |
+
" <td>-12</td>\n",
|
733 |
" <td>1</td>\n",
|
734 |
+
" <td>-3</td>\n",
|
735 |
" <td>-11</td>\n",
|
736 |
" <td>-7</td>\n",
|
737 |
+
" <td>-1</td>\n",
|
738 |
+
" <td>-4</td>\n",
|
|
|
739 |
" </tr>\n",
|
740 |
" <tr>\n",
|
741 |
" <th>100732</th>\n",
|
|
|
750 |
" <td>23</td>\n",
|
751 |
" <td>60</td>\n",
|
752 |
" <td>...</td>\n",
|
|
|
|
|
|
|
|
|
|
|
753 |
" <td>17</td>\n",
|
754 |
+
" <td>-9</td>\n",
|
|
|
755 |
" <td>12</td>\n",
|
756 |
" <td>-4</td>\n",
|
757 |
+
" <td>2</td>\n",
|
758 |
+
" <td>-2</td>\n",
|
759 |
+
" <td>-1</td>\n",
|
760 |
+
" <td>4</td>\n",
|
761 |
+
" <td>11</td>\n",
|
762 |
+
" <td>1</td>\n",
|
763 |
" </tr>\n",
|
764 |
" <tr>\n",
|
765 |
" <th>83150</th>\n",
|
|
|
774 |
" <td>27</td>\n",
|
775 |
" <td>58</td>\n",
|
776 |
" <td>...</td>\n",
|
777 |
+
" <td>10</td>\n",
|
778 |
" <td>-5</td>\n",
|
779 |
+
" <td>13</td>\n",
|
780 |
+
" <td>-6</td>\n",
|
781 |
+
" <td>2</td>\n",
|
782 |
" <td>1</td>\n",
|
783 |
" <td>4</td>\n",
|
|
|
|
|
|
|
784 |
" <td>-5</td>\n",
|
785 |
" <td>1</td>\n",
|
786 |
" <td>13</td>\n",
|
|
|
787 |
" </tr>\n",
|
788 |
" <tr>\n",
|
789 |
" <th>345009</th>\n",
|
|
|
798 |
" <td>19</td>\n",
|
799 |
" <td>55</td>\n",
|
800 |
" <td>...</td>\n",
|
|
|
|
|
|
|
|
|
|
|
801 |
" <td>-11</td>\n",
|
802 |
+
" <td>7</td>\n",
|
803 |
+
" <td>-7</td>\n",
|
804 |
+
" <td>13</td>\n",
|
805 |
" <td>-3</td>\n",
|
806 |
" <td>-1</td>\n",
|
807 |
+
" <td>2</td>\n",
|
808 |
+
" <td>-3</td>\n",
|
809 |
+
" <td>4</td>\n",
|
810 |
" <td>-7</td>\n",
|
|
|
811 |
" </tr>\n",
|
812 |
" <tr>\n",
|
813 |
" <th>318707</th>\n",
|
|
|
822 |
" <td>20</td>\n",
|
823 |
" <td>51</td>\n",
|
824 |
" <td>...</td>\n",
|
825 |
+
" <td>-11</td>\n",
|
826 |
" <td>2</td>\n",
|
827 |
+
" <td>-18</td>\n",
|
828 |
+
" <td>3</td>\n",
|
|
|
829 |
" <td>1</td>\n",
|
830 |
+
" <td>-3</td>\n",
|
831 |
+
" <td>2</td>\n",
|
832 |
" <td>-7</td>\n",
|
833 |
+
" <td>2</td>\n",
|
834 |
" <td>-3</td>\n",
|
|
|
|
|
835 |
" </tr>\n",
|
836 |
" </tbody>\n",
|
837 |
"</table>\n",
|
|
|
846 |
"345009 2019 4 3435 58 3292 65 H 0 \n",
|
847 |
"318707 2013 128 3322 45 3270 63 N 0 \n",
|
848 |
"\n",
|
849 |
+
" TeamFGM TeamFGA ... FTADiff PFDiff ScoreDiff FGADiff BlkDiff \\\n",
|
850 |
+
"337067 21 55 ... -11 9 -28 -12 1 \n",
|
851 |
+
"100732 23 60 ... 17 -9 12 -4 2 \n",
|
852 |
+
"83150 27 58 ... 10 -5 13 -6 2 \n",
|
853 |
+
"345009 19 55 ... -11 7 -7 13 -3 \n",
|
854 |
+
"318707 20 51 ... -11 2 -18 3 1 \n",
|
855 |
"\n",
|
856 |
+
" FGM3Diff ORDiff StlDiff AstDiff DRDiff \n",
|
857 |
+
"337067 -3 -11 -7 -1 -4 \n",
|
858 |
+
"100732 -2 -1 4 11 1 \n",
|
859 |
+
"83150 1 4 -5 1 13 \n",
|
860 |
+
"345009 -1 2 -3 4 -7 \n",
|
861 |
+
"318707 -3 2 -7 2 -3 \n",
|
862 |
"\n",
|
863 |
"[5 rows x 49 columns]"
|
864 |
]
|
|
|
907 |
"source": [
|
908 |
"# combine the two detailed game dataframes into one for future use\n",
|
909 |
"\n",
|
910 |
+
"all_detailed_games_df = pd.concat(\n",
|
911 |
+
" [\n",
|
912 |
+
" detailed_reg_games_df.assign(GameType=\"reg\"),\n",
|
913 |
+
" detailed_tourney_games_df.assign(GameType=\"tourney\"),\n",
|
914 |
+
" ]\n",
|
915 |
+
")"
|
916 |
]
|
917 |
},
|
918 |
{
|
|
|
1337 |
"source": [
|
1338 |
"team_reg_agg = (\n",
|
1339 |
" detailed_reg_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
|
1340 |
+
" .agg(\n",
|
1341 |
+
" {\n",
|
1342 |
+
" col: agg_funcs\n",
|
1343 |
+
" for col in detailed_reg_games_df.select_dtypes(\"number\").columns\n",
|
1344 |
+
" if col not in exclude_agg_cols\n",
|
1345 |
+
" }\n",
|
1346 |
+
" )\n",
|
1347 |
" .reset_index()\n",
|
1348 |
")\n",
|
1349 |
"\n",
|
|
|
1705 |
}
|
1706 |
],
|
1707 |
"source": [
|
1708 |
+
"# aggregate the same metrics for the tournament dataset\n",
|
1709 |
"\n",
|
1710 |
"team_tourney_agg = (\n",
|
1711 |
" detailed_tourney_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
|
1712 |
+
" .agg(\n",
|
1713 |
+
" {\n",
|
1714 |
+
" col: agg_funcs\n",
|
1715 |
+
" for col in detailed_tourney_games_df.select_dtypes(\"number\").columns\n",
|
1716 |
+
" if col not in exclude_agg_cols\n",
|
1717 |
+
" }\n",
|
1718 |
+
" )\n",
|
1719 |
" .reset_index()\n",
|
1720 |
")\n",
|
1721 |
"\n",
|
1722 |
+
"team_tourney_agg.columns = [\n",
|
1723 |
+
" \" \".join(col).strip() for col in team_tourney_agg.columns.values\n",
|
1724 |
+
"]\n",
|
1725 |
"\n",
|
1726 |
"team_tourney_agg.sample(10, random_state=1)"
|
1727 |
]
|
|
|
1772 |
" </thead>\n",
|
1773 |
" <tbody>\n",
|
1774 |
" <tr>\n",
|
1775 |
+
" <th>3591</th>\n",
|
1776 |
+
" <td>2004</td>\n",
|
1777 |
+
" <td>X02</td>\n",
|
1778 |
+
" <td>3243</td>\n",
|
1779 |
+
" <td>W</td>\n",
|
1780 |
+
" <td>big_twelve</td>\n",
|
1781 |
+
" <td>Kansas St</td>\n",
|
1782 |
+
" <td>NaN</td>\n",
|
1783 |
+
" <td>NaN</td>\n",
|
1784 |
+
" <td>2</td>\n",
|
1785 |
" </tr>\n",
|
1786 |
" <tr>\n",
|
1787 |
+
" <th>3528</th>\n",
|
1788 |
+
" <td>2013</td>\n",
|
1789 |
+
" <td>Y01</td>\n",
|
1790 |
+
" <td>3124</td>\n",
|
1791 |
+
" <td>W</td>\n",
|
1792 |
+
" <td>big_twelve</td>\n",
|
1793 |
+
" <td>Baylor</td>\n",
|
1794 |
+
" <td>NaN</td>\n",
|
1795 |
+
" <td>NaN</td>\n",
|
1796 |
+
" <td>1</td>\n",
|
1797 |
" </tr>\n",
|
1798 |
" <tr>\n",
|
1799 |
+
" <th>1891</th>\n",
|
1800 |
+
" <td>2003</td>\n",
|
1801 |
+
" <td>W02</td>\n",
|
1802 |
+
" <td>1448</td>\n",
|
1803 |
" <td>M</td>\n",
|
1804 |
+
" <td>acc</td>\n",
|
1805 |
+
" <td>Wake Forest</td>\n",
|
1806 |
+
" <td>1985.0</td>\n",
|
1807 |
+
" <td>2024.0</td>\n",
|
1808 |
+
" <td>2</td>\n",
|
1809 |
" </tr>\n",
|
1810 |
" <tr>\n",
|
1811 |
+
" <th>778</th>\n",
|
1812 |
+
" <td>2019</td>\n",
|
1813 |
+
" <td>Y01</td>\n",
|
1814 |
+
" <td>1314</td>\n",
|
1815 |
" <td>M</td>\n",
|
1816 |
+
" <td>acc</td>\n",
|
1817 |
+
" <td>North Carolina</td>\n",
|
1818 |
+
" <td>1985.0</td>\n",
|
1819 |
+
" <td>2024.0</td>\n",
|
1820 |
+
" <td>1</td>\n",
|
1821 |
" </tr>\n",
|
1822 |
" <tr>\n",
|
1823 |
+
" <th>2932</th>\n",
|
1824 |
+
" <td>2019</td>\n",
|
1825 |
+
" <td>X05</td>\n",
|
1826 |
+
" <td>3266</td>\n",
|
1827 |
+
" <td>W</td>\n",
|
1828 |
" <td>big_east</td>\n",
|
1829 |
+
" <td>Marquette</td>\n",
|
1830 |
+
" <td>NaN</td>\n",
|
1831 |
+
" <td>NaN</td>\n",
|
1832 |
+
" <td>5</td>\n",
|
1833 |
" </tr>\n",
|
1834 |
" </tbody>\n",
|
1835 |
"</table>\n",
|
1836 |
"</div>"
|
1837 |
],
|
1838 |
"text/plain": [
|
1839 |
+
" Season Seed TeamID League ConfAbbrev TeamName FirstD1Season \\\n",
|
1840 |
+
"3591 2004 X02 3243 W big_twelve Kansas St NaN \n",
|
1841 |
+
"3528 2013 Y01 3124 W big_twelve Baylor NaN \n",
|
1842 |
+
"1891 2003 W02 1448 M acc Wake Forest 1985.0 \n",
|
1843 |
+
"778 2019 Y01 1314 M acc North Carolina 1985.0 \n",
|
1844 |
+
"2932 2019 X05 3266 W big_east Marquette NaN \n",
|
1845 |
"\n",
|
1846 |
+
" LastD1Season ChalkSeed \n",
|
1847 |
+
"3591 NaN 2 \n",
|
1848 |
+
"3528 NaN 1 \n",
|
1849 |
+
"1891 2024.0 2 \n",
|
1850 |
+
"778 2024.0 1 \n",
|
1851 |
+
"2932 NaN 5 "
|
1852 |
]
|
1853 |
},
|
1854 |
"execution_count": 15,
|
|
|
1857 |
}
|
1858 |
],
|
1859 |
"source": [
|
1860 |
+
"conference_df = pd.concat(\n",
|
1861 |
+
" [\n",
|
1862 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneySeeds.csv\")).assign(League=\"M\"),\n",
|
1863 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneySeeds.csv\")).assign(League=\"W\"),\n",
|
1864 |
+
" ]\n",
|
1865 |
+
")\n",
|
|
|
1866 |
"\n",
|
1867 |
+
"team_conf_seeds_df = conference_df.merge(\n",
|
1868 |
+
" right=(\n",
|
1869 |
+
" pd.concat(\n",
|
1870 |
+
" [\n",
|
1871 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"MTeamConferences.csv\")).assign(\n",
|
1872 |
+
" League=\"M\"\n",
|
1873 |
+
" ),\n",
|
1874 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"WTeamConferences.csv\")).assign(\n",
|
1875 |
+
" League=\"W\"\n",
|
1876 |
+
" ),\n",
|
1877 |
+
" ]\n",
|
1878 |
+
" )\n",
|
1879 |
+
" ),\n",
|
1880 |
+
" on=[\"League\", \"Season\", \"TeamID\"],\n",
|
1881 |
+
" how=\"left\",\n",
|
1882 |
+
").merge(right=(\n",
|
1883 |
+
" pd.concat([\n",
|
1884 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"MTeams.csv\")),\n",
|
1885 |
+
" pd.read_csv(os.path.join(DATA_DIR, \"WTeams.csv\")),\n",
|
1886 |
+
" ])),\n",
|
1887 |
+
" on=\"TeamID\",\n",
|
1888 |
")\n",
|
1889 |
"\n",
|
1890 |
"team_conf_seeds_df[\"ChalkSeed\"] = team_conf_seeds_df.apply(\n",
|
|
|
1892 |
" axis=1,\n",
|
1893 |
")\n",
|
1894 |
"\n",
|
1895 |
+
"team_conf_seeds_df.sample(5, random_state=1)"
|
1896 |
]
|
1897 |
},
|
1898 |
{
|
|
|
2274 |
"source": [
|
2275 |
"# merge the tournament aggregated metrics with the regular season aggregated metrics\n",
|
2276 |
"team_agg_df = pd.merge(\n",
|
2277 |
+
" left=team_reg_agg,\n",
|
2278 |
+
" right=team_tourney_agg,\n",
|
2279 |
" how=\"left\",\n",
|
2280 |
+
" on=[\"TeamID\", \"Season\", \"League\"],\n",
|
2281 |
" suffixes=(\" reg\", \" tourney\"),\n",
|
2282 |
" validate=\"1:1\",\n",
|
2283 |
")\n",
|
|
|
2313 |
"output_type": "stream",
|
2314 |
"text": [
|
2315 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2316 |
+
"Int64Index: 13305 entries, 0 to 13304\n",
|
2317 |
"Columns: 459 entries, TeamID to ChalkSeed\n",
|
2318 |
+
"dtypes: float64(453), int64(2), object(4)\n",
|
2319 |
+
"memory usage: 46.7+ MB\n"
|
2320 |
]
|
2321 |
}
|
2322 |
],
|
|
|
2336 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2337 |
"Int64Index: 377608 entries, 0 to 377607\n",
|
2338 |
"Columns: 508 entries, Season to ChalkSeed\n",
|
2339 |
+
"dtypes: float64(453), int64(48), object(7)\n",
|
2340 |
"memory usage: 1.4+ GB\n"
|
2341 |
]
|
2342 |
}
|
|
|
2397 |
"<class 'pandas.core.frame.DataFrame'>\n",
|
2398 |
"Int64Index: 377608 entries, 0 to 377607\n",
|
2399 |
"Columns: 509 entries, Season to OppChalkSeed\n",
|
2400 |
+
"dtypes: float64(454), int64(48), object(7)\n",
|
2401 |
"memory usage: 1.4+ GB\n"
|
2402 |
]
|
2403 |
}
|
|
|
2405 |
"source": [
|
2406 |
"opp_chalk_seed_map = team_conf_seeds_df.groupby(\"TeamID\")[\"ChalkSeed\"].last()\n",
|
2407 |
"\n",
|
2408 |
+
"super_detailed_games_df[\"OppChalkSeed\"] = super_detailed_games_df[\"OppTeamID\"].map(\n",
|
2409 |
+
" opp_chalk_seed_map\n",
|
2410 |
+
")\n",
|
2411 |
"\n",
|
2412 |
"super_detailed_games_df.info()"
|
2413 |
]
|
|
|
2420 |
{
|
2421 |
"data": {
|
2422 |
"text/plain": [
|
2423 |
+
"0 2.0\n",
|
2424 |
+
"1 -4.0\n",
|
2425 |
+
"2 1.0\n",
|
2426 |
+
"3 NaN\n",
|
2427 |
+
"4 -9.0\n",
|
2428 |
+
" ... \n",
|
2429 |
+
"377603 1.0\n",
|
2430 |
+
"377604 2.0\n",
|
2431 |
+
"377605 -1.0\n",
|
2432 |
+
"377606 -2.0\n",
|
2433 |
+
"377607 -1.0\n",
|
2434 |
+
"Name: ChalkSeedDiff, Length: 377608, dtype: float64"
|
2435 |
]
|
2436 |
},
|
2437 |
"execution_count": 22,
|
|
|
2440 |
}
|
2441 |
],
|
2442 |
"source": [
|
2443 |
+
"super_detailed_games_df[\"ChalkSeedDiff\"] = (\n",
|
2444 |
+
" super_detailed_games_df[\"ChalkSeed\"] - super_detailed_games_df[\"OppChalkSeed\"]\n",
|
2445 |
+
")"
|
2446 |
]
|
2447 |
},
|
2448 |
{
|
src/visual_eval.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.metrics import roc_curve, precision_recall_curve
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def eval_binary_classification(pred: np.array, true: np.array):
|
7 |
+
plt.figure(figsize=(12, 6))
|
8 |
+
eval_roc_curve(pred, true)
|
9 |
+
eval_pr_curve(pred, true)
|
10 |
+
plt.tight_layout()
|
11 |
+
plt.show()
|
12 |
+
|
13 |
+
|
14 |
+
def eval_pr_curve(pred: np.array, true: np.array):
|
15 |
+
precision, recall, _ = precision_recall_curve(true, pred)
|
16 |
+
plt.subplot(1, 2, 1)
|
17 |
+
plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
|
18 |
+
plt.ylim(0)
|
19 |
+
plt.xlabel("Recall")
|
20 |
+
plt.ylabel("Precision")
|
21 |
+
plt.title("Precision-Recall Curve")
|
22 |
+
plt.legend(loc="lower right")
|
23 |
+
|
24 |
+
|
25 |
+
def eval_roc_curve(pred: np.array, true: np.array) -> None:
|
26 |
+
false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)
|
27 |
+
plt.subplot(1, 2, 2)
|
28 |
+
plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
|
29 |
+
plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
|
30 |
+
plt.title("ROC Curve vs. Random")
|
31 |
+
plt.xlabel("False Positive Rate")
|
32 |
+
plt.ylabel("True Positive Rate")
|
33 |
+
plt.legend(loc="lower right")
|
src/visualizations.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import matplotlib.pyplot as plt
|
3 |
-
from sklearn.metrics import (
|
4 |
-
roc_curve,
|
5 |
-
roc_auc_score,
|
6 |
-
precision_recall_curve,
|
7 |
-
)
|
8 |
-
|
9 |
-
|
10 |
-
def roc_plot(y_true: np.array, y_pred: np.array):
|
11 |
-
fpr, tpr, _ = roc_curve(y_true, y_pred)
|
12 |
-
|
13 |
-
# Plot ROC Curve
|
14 |
-
plt.plot(fpr, tpr, label="ROC Curve")
|
15 |
-
plt.plot([0, 1], [0, 1], linestyle="--", label="Random Model")
|
16 |
-
plt.title("ROC Curve vs. Random Model")
|
17 |
-
plt.xlabel("False Positive Rate")
|
18 |
-
plt.ylabel("True Positive Rate")
|
19 |
-
plt.legend("lower right")
|
20 |
-
|
21 |
-
plt.tight_layout()
|
22 |
-
plt.show()
|
23 |
-
|
24 |
-
|
25 |
-
def precision_recall_plot(y_true: np.array, y_pred: np.array, baseline_pred: np.array):
|
26 |
-
...
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|